diff --git a/.github/workflows/build_and_run_chain_simulator_and_execute_system_test.yml b/.github/workflows/build_and_run_chain_simulator_and_execute_system_test.yml index 994bd217090..93b5425d9f4 100644 --- a/.github/workflows/build_and_run_chain_simulator_and_execute_system_test.yml +++ b/.github/workflows/build_and_run_chain_simulator_and_execute_system_test.yml @@ -62,20 +62,27 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | - // Get the latest comment + // Get all comments const comments = await github.rest.issues.listComments({ owner: context.repo.owner, repo: context.repo.repo, issue_number: context.issue.number, }); - const lastComment = comments.data.pop(); // Get the last comment + // Find the last comment that contains 'Run Tests:' + let lastTestComment = null; + for (let i = comments.data.length - 1; i >= 0; i--) { + if (comments.data[i].body.includes('Run Tests:')) { + lastTestComment = comments.data[i]; + break; + } + } - if (lastComment && lastComment.body.includes('Run Tests:')) { - const body = lastComment.body.trim(); + if (lastTestComment) { + const body = lastTestComment.body.trim(); core.setOutput('latest_comment', body); - // Parse the branches from the last comment + // Parse the branches from the last test comment const simulatorBranchMatch = body.match(/mx-chain-simulator-go:\s*(\S+)/); const testingSuiteBranchMatch = body.match(/mx-chain-testing-suite:\s*(\S+)/); @@ -86,8 +93,11 @@ jobs: if (testingSuiteBranchMatch) { core.exportVariable('MX_CHAIN_TESTING_SUITE_TARGET_BRANCH', testingSuiteBranchMatch[1]); } + + // Log which comment was used for configuration + core.info(`Found 'Run Tests:' comment from ${lastTestComment.user.login} at ${lastTestComment.created_at}`); } else { - core.info('The last comment does not contain "Run Tests:". Skipping branch override.'); + core.info('No comment containing "Run Tests:" was found. Using default branch settings.'); } @@ -146,6 +156,82 @@ jobs: go build echo "CHAIN_SIMULATOR_BUILD_PATH=$(pwd)" >> $GITHUB_ENV + - name: Initialize Chain Simulator + run: | + cd mx-chain-simulator-go/cmd/chainsimulator + + # Start ChainSimulator with minimal args to initialize configs + INIT_LOG_FILE="/tmp/chainsim_init.log" + echo "Starting ChainSimulator initialization process..." + ./chainsimulator > $INIT_LOG_FILE 2>&1 & + INIT_PROCESS_PID=$! + + # Verify the process is running + if ! ps -p $INIT_PROCESS_PID > /dev/null; then + echo "Failed to start ChainSimulator process" + cat $INIT_LOG_FILE + exit 1 + fi + + # Wait for the initialization to complete - look for multiple possible success patterns + INIT_COMPLETED=false + RETRY_COUNT=0 + MAX_RETRIES=60 # Increase timeout to 60 seconds + + echo "Waiting for ChainSimulator initialization..." + while [ $RETRY_COUNT -lt $MAX_RETRIES ]; do + # Check for any of these success patterns + if grep -q "starting as observer node" $INIT_LOG_FILE || \ + grep -q "ChainSimulator started successfully" $INIT_LOG_FILE || \ + grep -q "initialized the node" $INIT_LOG_FILE || \ + grep -q "Node is running" $INIT_LOG_FILE; then + INIT_COMPLETED=true + echo "ChainSimulator initialization completed successfully" + break + fi + + # If there's a known fatal error, exit early + if grep -q "fatal error" $INIT_LOG_FILE || grep -q "panic:" $INIT_LOG_FILE; then + echo "Fatal error detected during initialization:" + grep -A 10 -E "fatal error|panic:" $INIT_LOG_FILE + break + fi + + # Print progress every 10 seconds + if [ $((RETRY_COUNT % 10)) -eq 0 ]; then + echo "Still waiting for initialization... 
($RETRY_COUNT seconds elapsed)" + tail -5 $INIT_LOG_FILE + fi + + RETRY_COUNT=$((RETRY_COUNT+1)) + sleep 1 + done + + # Kill the initialization process - try graceful shutdown first + echo "Stopping initialization process (PID: $INIT_PROCESS_PID)..." + kill -TERM $INIT_PROCESS_PID 2>/dev/null || true + sleep 3 + + # Check if process still exists and force kill if needed + if ps -p $INIT_PROCESS_PID > /dev/null 2>&1; then + echo "Process still running, forcing kill..." + kill -9 $INIT_PROCESS_PID 2>/dev/null || true + sleep 1 + fi + + if [ "$INIT_COMPLETED" != "true" ]; then + echo "ChainSimulator initialization failed after $MAX_RETRIES seconds" + echo "Last 20 lines of log:" + tail -20 $INIT_LOG_FILE + exit 1 + fi + + # Create a marker file to indicate successful initialization + touch /tmp/chain_simulator_initialized.lock + echo "Chain Simulator successfully initialized" + + echo "Initialization log stored at: $INIT_LOG_FILE" + - name: Checkout mx-chain-testing-suite uses: actions/checkout@v4 with: diff --git a/api/errors/errors.go b/api/errors/errors.go index 3f4e495b9d2..88ebeeec1c2 100644 --- a/api/errors/errors.go +++ b/api/errors/errors.go @@ -28,6 +28,9 @@ var ErrGetValueForKey = errors.New("get value for key error") // ErrGetKeyValuePairs signals an error in getting the key-value pairs of a key for an account var ErrGetKeyValuePairs = errors.New("get key-value pairs error") +// ErrIterateKeys signals an error in iterating over the keys of an account +var ErrIterateKeys = errors.New("iterate keys error") + // ErrGetESDTBalance signals an error in getting esdt balance for given address var ErrGetESDTBalance = errors.New("get esdt balance for account error") @@ -43,6 +46,12 @@ var ErrGetESDTNFTData = errors.New("get esdt nft data for account error") // ErrEmptyAddress signals that an empty address was provided var ErrEmptyAddress = errors.New("address is empty") +// ErrEmptyNumKeys signals that an empty numKeys was provided +var ErrEmptyNumKeys = errors.New("numKeys is empty") + +// ErrEmptyCheckpointId signals that an empty checkpointId was provided +var ErrEmptyCheckpointId = errors.New("checkpointId is empty") + // ErrEmptyKey signals that an empty key was provided var ErrEmptyKey = errors.New("key is empty") diff --git a/api/groups/addressGroup.go b/api/groups/addressGroup.go index 151b7f53372..a9a15957328 100644 --- a/api/groups/addressGroup.go +++ b/api/groups/addressGroup.go @@ -32,6 +32,7 @@ const ( getRegisteredNFTsPath = "/:address/registered-nfts" getESDTNFTDataPath = "/:address/nft/:tokenIdentifier/nonce/:nonce" getGuardianData = "/:address/guardian-data" + iterateKeysPath = "/iterate-keys" urlParamOnFinalBlock = "onFinalBlock" urlParamOnStartOfEpoch = "onStartOfEpoch" urlParamBlockNonce = "blockNonce" @@ -55,6 +56,7 @@ type addressFacadeHandler interface { GetESDTsWithRole(address string, role string, options api.AccountQueryOptions) ([]string, api.BlockInfo, error) GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) IsDataTrieMigrated(address string, options api.AccountQueryOptions) (bool, error) IsInterfaceNil() bool @@ -134,6 +136,11 @@ 
func NewAddressGroup(facade addressFacadeHandler) (*addressGroup, error) { Method: http.MethodGet, Handler: ag.getKeyValuePairs, }, + { + Path: iterateKeysPath, + Method: http.MethodPost, + Handler: ag.iterateKeys, + }, { Path: getESDTBalancePath, Method: http.MethodGet, @@ -327,7 +334,7 @@ func (ag *addressGroup) getGuardianData(c *gin.Context) { shared.RespondWithSuccess(c, gin.H{"guardianData": guardianData, "blockInfo": blockInfo}) } -// addressGroup returns all the key-value pairs for the given address +// getKeyValuePairs returns all the key-value pairs for the given address func (ag *addressGroup) getKeyValuePairs(c *gin.Context) { addr, options, err := extractBaseParams(c) if err != nil { @@ -344,6 +351,47 @@ func (ag *addressGroup) getKeyValuePairs(c *gin.Context) { shared.RespondWithSuccess(c, gin.H{"pairs": value, "blockInfo": blockInfo}) } +// IterateKeysRequest defines the request structure for iterating keys +type IterateKeysRequest struct { + Address string `json:"address"` + NumKeys uint `json:"numKeys"` + IteratorState [][]byte `json:"iteratorState"` +} + +// iterateKeys iterates keys for the given address +func (ag *addressGroup) iterateKeys(c *gin.Context) { + var iterateKeysRequest = &IterateKeysRequest{} + err := c.ShouldBindJSON(&iterateKeysRequest) + if err != nil { + shared.RespondWithValidationError(c, errors.ErrValidation, err) + return + } + + if len(iterateKeysRequest.Address) == 0 { + shared.RespondWithValidationError(c, errors.ErrValidation, errors.ErrEmptyAddress) + return + } + + options, err := extractAccountQueryOptions(c) + if err != nil { + shared.RespondWithValidationError(c, errors.ErrIterateKeys, err) + return + } + + value, newIteratorState, blockInfo, err := ag.getFacade().IterateKeys( + iterateKeysRequest.Address, + iterateKeysRequest.NumKeys, + iterateKeysRequest.IteratorState, + options, + ) + if err != nil { + shared.RespondWithInternalError(c, errors.ErrIterateKeys, err) + return + } + + shared.RespondWithSuccess(c, gin.H{"pairs": value, "newIteratorState": newIteratorState, "blockInfo": blockInfo}) +} + // getESDTBalance returns the balance for the given address and esdt token func (ag *addressGroup) getESDTBalance(c *gin.Context) { addr, tokenIdentifier, options, err := extractGetESDTBalanceParams(c) diff --git a/api/groups/addressGroup_test.go b/api/groups/addressGroup_test.go index bb19bb81d2c..b98dd225082 100644 --- a/api/groups/addressGroup_test.go +++ b/api/groups/addressGroup_test.go @@ -112,7 +112,7 @@ type esdtTokensCompleteResponseData struct { type esdtTokensCompleteResponse struct { Data esdtTokensCompleteResponseData `json:"data"` Error string `json:"error"` - Code string + Code string `json:"code"` } type keyValuePairsResponseData struct { @@ -122,7 +122,17 @@ type keyValuePairsResponseData struct { type keyValuePairsResponse struct { Data keyValuePairsResponseData `json:"data"` Error string `json:"error"` - Code string + Code string `json:"code"` +} + +type iterateKeysResponseData struct { + Pairs map[string]string `json:"pairs"` + NewIteratorState [][]byte `json:"newIteratorState"` +} +type iterateKeysResponse struct { + Data iterateKeysResponseData `json:"data"` + Error string `json:"error"` + Code string `json:"code"` } type esdtRolesResponseData struct { @@ -132,7 +142,7 @@ type esdtRolesResponseData struct { type esdtRolesResponse struct { Data esdtRolesResponseData `json:"data"` Error string `json:"error"` - Code string + Code string `json:"code"` } type usernameResponseData struct { @@ -662,6 +672,106 @@ func 
TestAddressGroup_getKeyValuePairs(t *testing.T) { }) } +func TestAddressGroup_iterateKeys(t *testing.T) { + t.Parallel() + + t.Run("invalid body should error", + testErrorScenario("/address/iterate-keys", "POST", bytes.NewBuffer([]byte("invalid body")), + formatExpectedErr(apiErrors.ErrValidation, errors.New("invalid character 'i' looking for beginning of value")))) + t.Run("empty address should error", func(t *testing.T) { + t.Parallel() + + body := &groups.IterateKeysRequest{ + Address: "", + } + bodyBytes, _ := json.Marshal(body) + testAddressGroup( + t, + &mock.FacadeStub{}, + "/address/iterate-keys", + "POST", + bytes.NewBuffer(bodyBytes), + http.StatusBadRequest, + formatExpectedErr(apiErrors.ErrValidation, apiErrors.ErrEmptyAddress), + ) + }) + t.Run("invalid query options should error", func(t *testing.T) { + t.Parallel() + + body := &groups.IterateKeysRequest{ + Address: "erd1", + } + bodyBytes, _ := json.Marshal(body) + testAddressGroup( + t, + &mock.FacadeStub{}, + "/address/iterate-keys?blockNonce=not-uint64", + "POST", + bytes.NewBuffer(bodyBytes), + http.StatusBadRequest, + formatExpectedErr(apiErrors.ErrIterateKeys, apiErrors.ErrBadUrlParams), + ) + }) + t.Run("with node fail should err", func(t *testing.T) { + t.Parallel() + + body := &groups.IterateKeysRequest{ + Address: "erd1", + } + bodyBytes, _ := json.Marshal(body) + facade := &mock.FacadeStub{ + IterateKeysCalled: func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + return nil, nil, api.BlockInfo{}, expectedErr + }, + } + testAddressGroup( + t, + facade, + "/address/iterate-keys", + "POST", + bytes.NewBuffer(bodyBytes), + http.StatusInternalServerError, + formatExpectedErr(apiErrors.ErrIterateKeys, expectedErr), + ) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + pairs := map[string]string{ + "k1": "v1", + "k2": "v2", + } + + body := &groups.IterateKeysRequest{ + Address: "erd1", + NumKeys: 10, + IteratorState: [][]byte{[]byte("starting"), []byte("state")}, + } + newIteratorState := [][]byte{[]byte("new"), []byte("state")} + bodyBytes, _ := json.Marshal(body) + facade := &mock.FacadeStub{ + IterateKeysCalled: func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + assert.Equal(t, body.Address, address) + assert.Equal(t, body.NumKeys, numKeys) + assert.Equal(t, body.IteratorState, iteratorState) + return pairs, newIteratorState, api.BlockInfo{}, nil + }, + } + + response := &iterateKeysResponse{} + loadAddressGroupResponse( + t, + facade, + "/address/iterate-keys", + "POST", + bytes.NewBuffer(bodyBytes), + response, + ) + assert.Equal(t, pairs, response.Data.Pairs) + assert.Equal(t, newIteratorState, response.Data.NewIteratorState) + }) +} + func TestAddressGroup_getESDTBalance(t *testing.T) { t.Parallel() @@ -1143,6 +1253,7 @@ func getAddressRoutesConfig() config.ApiRoutesConfig { {Name: "/:address/username", Open: true}, {Name: "/:address/code-hash", Open: true}, {Name: "/:address/keys", Open: true}, + {Name: "/iterate-keys", Open: true}, {Name: "/:address/key/:key", Open: true}, {Name: "/:address/esdt", Open: true}, {Name: "/:address/esdts/roles", Open: true}, diff --git a/api/mock/facadeStub.go b/api/mock/facadeStub.go index 62de2febc81..94bc0551c76 100644 --- a/api/mock/facadeStub.go +++ b/api/mock/facadeStub.go @@ -49,6 +49,7 @@ type FacadeStub struct { GetUsernameCalled func(address string, options 
api.AccountQueryOptions) (string, api.BlockInfo, error) GetCodeHashCalled func(address string, options api.AccountQueryOptions) ([]byte, api.BlockInfo, error) GetKeyValuePairsCalled func(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + IterateKeysCalled func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) SimulateTransactionExecutionHandler func(tx *transaction.Transaction) (*txSimData.SimulationResultsWithVMOutput, error) GetESDTDataCalled func(address string, key string, nonce uint64, options api.AccountQueryOptions) (*esdt.ESDigitalToken, api.BlockInfo, error) GetAllESDTTokensCalled func(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) @@ -241,6 +242,15 @@ func (f *FacadeStub) GetKeyValuePairs(address string, options api.AccountQueryOp return nil, api.BlockInfo{}, nil } +// IterateKeys - +func (f *FacadeStub) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + if f.IterateKeysCalled != nil { + return f.IterateKeysCalled(address, numKeys, iteratorState, options) + } + + return nil, nil, api.BlockInfo{}, nil +} + // GetGuardianData - func (f *FacadeStub) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) { if f.GetGuardianDataCalled != nil { diff --git a/api/shared/interface.go b/api/shared/interface.go index 206cea6ee30..adedd6642af 100644 --- a/api/shared/interface.go +++ b/api/shared/interface.go @@ -74,6 +74,7 @@ type FacadeHandler interface { GetESDTsWithRole(address string, role string, options api.AccountQueryOptions) ([]string, api.BlockInfo, error) GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) GetBlockByHash(hash string, options api.BlockQueryOptions) (*api.Block, error) GetBlockByNonce(nonce uint64, options api.BlockQueryOptions) (*api.Block, error) diff --git a/cmd/node/config/api.toml b/cmd/node/config/api.toml index fcf9cf7fc0b..24f641f78c9 100644 --- a/cmd/node/config/api.toml +++ b/cmd/node/config/api.toml @@ -79,6 +79,9 @@ # /address/:address/keys will return all the key-value pairs of a given account { Name = "/:address/keys", Open = true }, + # address/iterate-keys will return the given num of key-value pairs for the given account. The iteration will start from the given starting state + { Name = "/iterate-keys", Open = true }, + # /address/:address/key/:key will return the value of a key for a given account { Name = "/:address/key/:key", Open = true }, diff --git a/cmd/node/config/config.toml b/cmd/node/config/config.toml index 4da40dfde30..5c6da7db57a 100644 --- a/cmd/node/config/config.toml +++ b/cmd/node/config/config.toml @@ -40,6 +40,18 @@ # Make sure that this is greater than the unbonding period! 
SetGuardianEpochsDelay = 2 # TODO: for mainnet should be 20, 2 is just for testing + # ChainParametersByEpoch defines chain operation configurable values that can be modified based on epochs + ChainParametersByEpoch = [ + { EnableEpoch = 0, RoundDuration = 6000, ShardConsensusGroupSize = 7, ShardMinNumNodes = 10, MetachainConsensusGroupSize = 10, MetachainMinNumNodes = 10, Hysteresis = 0.2, Adaptivity = false }, + { EnableEpoch = 1, RoundDuration = 6000, ShardConsensusGroupSize = 10, ShardMinNumNodes = 10, MetachainConsensusGroupSize = 10, MetachainMinNumNodes = 10, Hysteresis = 0.2, Adaptivity = false } + ] + + # EpochChangeGracePeriodEnableEpoch represents the configuration of different grace periods for epoch change with their activation epochs + EpochChangeGracePeriodByEpoch = [ + {EnableEpoch = 0, GracePeriodInRounds = 1 }, + {EnableEpoch = 1, GracePeriodInRounds = 10 }, # Andromeda epoch comes with a longer grace period + ] + [HardwareRequirements] CPUFlags = ["SSE4", "SSE42"] @@ -186,6 +198,19 @@ MaxBatchSize = 100 MaxOpenFiles = 10 +[ProofsStorage] + [ProofsStorage.Cache] + Name = "ProofsStorage" + Capacity = 1000 + Type = "SizeLRU" + SizeInBytes = 20971520 #20MB + [ProofsStorage.DB] + FilePath = "Proofs" + Type = "LvlDBSerial" + BatchDelaySeconds = 2 + MaxBatchSize = 100 + MaxOpenFiles = 10 + [TxStorage] [TxStorage.Cache] Name = "TxStorage" @@ -365,6 +390,10 @@ MaxHeadersPerShard = 1000 NumElementsToRemoveOnEviction = 200 +[ProofsPoolConfig] + CleanupNonceDelta = 3 + BucketSize = 100 + [BadBlocksCache] Name = "BadBlocksCache" Capacity = 1000 @@ -497,7 +526,7 @@ IntervalInSeconds = 1 ReservedPercent = 20.0 [Antiflood.FastReacting.PeerMaxInput] - BaseMessagesPerInterval = 140 + BaseMessagesPerInterval = 280 TotalSizePerInterval = 4194304 #4MB/s [Antiflood.FastReacting.PeerMaxInput.IncreaseFactor] Threshold = 10 #if consensus size will exceed this value, then @@ -539,7 +568,7 @@ PeerBanDurationInSeconds = 3600 [Antiflood.PeerMaxOutput] - BaseMessagesPerInterval = 75 + BaseMessagesPerInterval = 150 TotalSizePerInterval = 2097152 #2MB/s [Antiflood.Cache] @@ -663,6 +692,10 @@ MaxPeerTrieLevelInMemory = 5 StateStatisticsEnabled = false +[TrieLeavesRetrieverConfig] + Enabled = false + MaxSizeInBytes = 10485760 #10MB + [BlockSizeThrottleConfig] MinSizeInBytes = 104857 # 104857 is 10% from 1MB MaxSizeInBytes = 943718 # 943718 is 90% from 1MB @@ -672,8 +705,7 @@ TimeOutForSCExecutionInMilliseconds = 10000 # 10 seconds = 10000 milliseconds WasmerSIGSEGVPassthrough = false # must be false for release WasmVMVersions = [ - { StartEpoch = 0, Version = "v1.4" }, - { StartEpoch = 1, Version = "v1.5" }, # TODO: set also the RoundActivations.DisableAsyncCallV1 accordingly + { StartEpoch = 0, Version = "v1.5" }, ] TransferAndExecuteByUserAddresses = [ # TODO: set real contract addresses for all shards "erd1qqqqqqqqqqqqqpgqr46jrxr6r2unaqh75ugd308dwx5vgnhwh47qtvepe3", #shard 0 @@ -684,8 +716,7 @@ TimeOutForSCExecutionInMilliseconds = 10000 # 10 seconds = 10000 milliseconds WasmerSIGSEGVPassthrough = false # must be false for release WasmVMVersions = [ - { StartEpoch = 0, Version = "v1.4" }, - { StartEpoch = 1, Version = "v1.5" }, # TODO: set also the RoundActivations.DisableAsyncCallV1 accordingly + { StartEpoch = 0, Version = "v1.5" }, ] TransferAndExecuteByUserAddresses = [ # TODO: set real contract addresses for all shards "erd1qqqqqqqqqqqqqpgqr46jrxr6r2unaqh75ugd308dwx5vgnhwh47qtvepe3", @@ -946,3 +977,7 @@ # MaxRoundsOfInactivityAccepted defines the number of rounds missed by a main or higher level 
backup machine before # the current machine will take over and propose/sign blocks. Used in both single-key and multi-key modes. MaxRoundsOfInactivityAccepted = 3 + +[InterceptedDataVerifier] + CacheSpanInSec = 30 + CacheExpiryInSec = 30 diff --git a/cmd/node/config/enableEpochs.toml b/cmd/node/config/enableEpochs.toml index fb37961adde..1956786ee7d 100644 --- a/cmd/node/config/enableEpochs.toml +++ b/cmd/node/config/enableEpochs.toml @@ -6,7 +6,7 @@ BuiltInFunctionsEnableEpoch = 0 # RelayedTransactionsEnableEpoch represents the epoch when the relayed transactions will be enabled - RelayedTransactionsEnableEpoch = 1 + RelayedTransactionsEnableEpoch = 0 # PenalizedTooMuchGasEnableEpoch represents the epoch when the penalization for using too much gas will be enabled PenalizedTooMuchGasEnableEpoch = 0 @@ -82,7 +82,7 @@ CorrectLastUnjailedEnableEpoch = 1 # RelayedTransactionsV2EnableEpoch represents the epoch when the relayed transactions V2 will be enabled - RelayedTransactionsV2EnableEpoch = 1 + RelayedTransactionsV2EnableEpoch = 0 # UnbondTokensV2EnableEpoch represents the epoch when the new implementation of the unbond tokens function is available UnbondTokensV2EnableEpoch = 1 @@ -345,6 +345,12 @@ # RelayedTransactionsV3FixESDTTransferEnableEpoch represents the epoch when the fix for relayed transactions v3 with esdt transfer will be enabled RelayedTransactionsV3FixESDTTransferEnableEpoch = 1 # TODO: keep this equal to RelayedTransactionsV3EnableEpoch for mainnet + # AndromedaEnableEpoch represents the epoch when the equivalent messages and fix order for consensus features are enabled + AndromedaEnableEpoch = 1 + + # CheckBuiltInCallOnTransferValueAndFailEnableRound represents the ROUND when the check on transfer value fix is activated + CheckBuiltInCallOnTransferValueAndFailEnableRound = 1 + # MaskVMInternalDependenciesErrorsEnableEpoch represents the epoch when the additional internal erorr masking in vm is enabled MaskVMInternalDependenciesErrorsEnableEpoch = 2 @@ -380,5 +386,5 @@ [GasSchedule] # GasScheduleByEpochs holds the configuration for the gas schedule that will be applied from specific epochs GasScheduleByEpochs = [ - { StartEpoch = 0, FileName = "gasScheduleV8.toml" }, + { StartEpoch = 0, FileName = "gasScheduleV9.toml" }, ] diff --git a/cmd/node/config/enableRounds.toml b/cmd/node/config/enableRounds.toml index d7be75bb524..6258483fb13 100644 --- a/cmd/node/config/enableRounds.toml +++ b/cmd/node/config/enableRounds.toml @@ -10,4 +10,4 @@ [RoundActivations] [RoundActivations.DisableAsyncCallV1] Options = [] - Round = "100" + Round = "0" diff --git a/cmd/node/config/gasSchedules/gasScheduleV9.toml b/cmd/node/config/gasSchedules/gasScheduleV9.toml new file mode 100644 index 00000000000..aa437024c7b --- /dev/null +++ b/cmd/node/config/gasSchedules/gasScheduleV9.toml @@ -0,0 +1,856 @@ +[BuiltInCost] + ChangeOwnerAddress = 5000000 + ClaimDeveloperRewards = 5000000 + SaveUserName = 1000000 + SaveKeyValue = 100000 + ESDTTransfer = 200000 + ESDTBurn = 100000 + ESDTLocalMint = 50000 + ESDTLocalBurn = 50000 + ESDTNFTCreate = 150000 + ESDTNFTAddQuantity = 50000 + ESDTNFTBurn = 50000 + ESDTNFTTransfer = 200000 + ESDTNFTChangeCreateOwner = 1000000 + ESDTNFTAddUri = 50000 + ESDTNFTUpdateAttributes = 50000 + ESDTNFTMultiTransfer = 200000 + MultiESDTNFTTransfer = 200000 # should be the same value with the ESDTNFTMultiTransfer + ESDTModifyRoyalties = 500000 + ESDTModifyCreator = 500000 + ESDTNFTRecreate = 1000000 + ESDTNFTUpdate = 1000000 + ESDTNFTSetNewURIs = 500000 + SetGuardian = 250000 + 
GuardAccount = 250000 + UnGuardAccount = 250000 + TrieLoadPerNode = 100000 + TrieStorePerNode = 50000 + +[MetaChainSystemSCsCost] + Stake = 5000000 + UnStake = 5000000 + UnBond = 5000000 + Claim = 5000000 + Get = 5000000 + ChangeRewardAddress = 5000000 + ChangeValidatorKeys = 5000000 + UnJail = 5000000 + DelegationOps = 1000000 + DelegationMgrOps = 50000000 + ValidatorToDelegation = 500000000 + ESDTIssue = 50000000 + ESDTOperations = 50000000 + Proposal = 50000000 + Vote = 5000000 + DelegateVote = 50000000 + RevokeVote = 50000000 + CloseProposal = 50000000 + ClearProposal = 50000000 + ClaimAccumulatedFees = 1000000 + ChangeConfig = 50000000 + GetAllNodeStates = 20000000 + UnstakeTokens = 5000000 + UnbondTokens = 5000000 + GetActiveFund = 50000 + FixWaitingListSize = 500000000 + +[BaseOperationCost] + StorePerByte = 10000 + ReleasePerByte = 1000 + DataCopyPerByte = 50 + PersistPerByte = 1000 + CompilePerByte = 300 + AoTPreparePerByte = 100 + GetCode = 1000000 + +[BaseOpsAPICost] + GetSCAddress = 1000 + GetOwnerAddress = 5000 + IsSmartContract = 5000 + GetShardOfAddress = 5000 + GetExternalBalance = 7000 + GetBlockHash = 10000 + TransferValue = 100000 + GetArgument = 1000 + GetFunction = 1000 + GetNumArguments = 1000 + StorageStore = 75000 + StorageLoad = 50000 + CachedStorageLoad = 1000 + GetCaller = 1000 + GetCallValue = 1000 + Log = 3750 + Finish = 1 + SignalError = 1 + GetBlockTimeStamp = 10000 + GetGasLeft = 1000 + Int64GetArgument = 1000 + Int64StorageStore = 75000 + Int64StorageLoad = 50000 + Int64Finish = 1000 + GetStateRootHash = 10000 + GetBlockNonce = 10000 + GetBlockEpoch = 10000 + GetBlockRound = 10000 + GetBlockRandomSeed = 10000 + ExecuteOnSameContext = 100000 + ExecuteOnDestContext = 100000 + DelegateExecution = 100000 + AsyncCallStep = 100000 + AsyncCallbackGasLock = 4000000 + ExecuteReadOnly = 160000 + CreateContract = 300000 + GetReturnData = 1000 + GetNumReturnData = 1000 + GetReturnDataSize = 1000 + GetOriginalTxHash = 10000 + CleanReturnData = 1000 + DeleteFromReturnData = 1000 + GetPrevTxHash = 10000 + GetCurrentTxHash = 10000 + CreateAsyncCall = 200000 + SetAsyncCallback = 100000 + SetAsyncGroupCallback = 100000 + SetAsyncContextCallback = 100000 + GetCallbackClosure = 10000 + GetCodeMetadata = 10000 + GetCodeHash = 10000 + IsBuiltinFunction = 10000 + IsReservedFunctionName = 10000 + GetRoundTime = 10000 + EpochStartBlockTimeStamp = 10000 + EpochStartBlockNonce = 10000 + EpochStartBlockRound = 10000 + +[EthAPICost] + UseGas = 100 + GetAddress = 100000 + GetExternalBalance = 70000 + GetBlockHash = 100000 + Call = 160000 + CallDataCopy = 200 + GetCallDataSize = 100 + CallCode = 160000 + CallDelegate = 160000 + CallStatic = 160000 + StorageStore = 250000 + StorageLoad = 100000 + GetCaller = 100 + GetCallValue = 100 + CodeCopy = 1000 + GetCodeSize = 100 + GetBlockCoinbase = 100 + Create = 320000 + GetBlockDifficulty = 100 + ExternalCodeCopy = 3000 + GetExternalCodeSize = 2500 + GetGasLeft = 100 + GetBlockGasLimit = 100000 + GetTxGasPrice = 1000 + Log = 3750 + GetBlockNumber = 100000 + GetTxOrigin = 100000 + Finish = 1 + Revert = 1 + GetReturnDataSize = 200 + ReturnDataCopy = 500 + SelfDestruct = 5000000 + GetBlockTimeStamp = 100000 + +[BigIntAPICost] + BigIntNew = 2000 + BigIntByteLength = 2000 + BigIntUnsignedByteLength = 2000 + BigIntSignedByteLength = 2000 + BigIntGetBytes = 2000 + BigIntGetUnsignedBytes = 2000 + BigIntGetSignedBytes = 2000 + BigIntSetBytes = 2000 + BigIntSetUnsignedBytes = 2000 + BigIntSetSignedBytes = 2000 + BigIntIsInt64 = 2000 + BigIntGetInt64 = 
2000 + BigIntSetInt64 = 2000 + BigIntAdd = 2000 + BigIntSub = 2000 + BigIntMul = 6000 + BigIntSqrt = 6000 + BigIntPow = 6000 + BigIntLog = 6000 + BigIntTDiv = 6000 + BigIntTMod = 6000 + BigIntEDiv = 6000 + BigIntEMod = 6000 + BigIntAbs = 2000 + BigIntNeg = 2000 + BigIntSign = 2000 + BigIntCmp = 2000 + BigIntNot = 2000 + BigIntAnd = 2000 + BigIntOr = 2000 + BigIntXor = 2000 + BigIntShr = 2000 + BigIntShl = 2000 + BigIntFinishUnsigned = 1000 + BigIntFinishSigned = 1000 + BigIntStorageLoadUnsigned = 50000 + BigIntStorageStoreUnsigned = 75000 + BigIntGetArgument = 1000 + BigIntGetUnsignedArgument = 1000 + BigIntGetSignedArgument = 1000 + BigIntGetCallValue = 1000 + BigIntGetExternalBalance = 10000 + CopyPerByteForTooBig = 1000 + +[CryptoAPICost] + SHA256 = 1000000 + Keccak256 = 1000000 + Ripemd160 = 1000000 + VerifyBLS = 5000000 + VerifyEd25519 = 2000000 + VerifySecp256k1 = 2000000 + EllipticCurveNew = 10000 + AddECC = 75000 + DoubleECC = 65000 + IsOnCurveECC = 10000 + ScalarMultECC = 400000 + MarshalECC = 13000 + MarshalCompressedECC = 15000 + UnmarshalECC = 20000 + UnmarshalCompressedECC = 270000 + GenerateKeyECC = 7000000 + EncodeDERSig = 10000000 + VerifySecp256r1 = 2000000 + VerifyBLSSignatureShare = 2000000 + VerifyBLSMultiSig = 2000000 + +[ManagedBufferAPICost] + MBufferNew = 2000 + MBufferNewFromBytes = 2000 + MBufferGetLength = 2000 + MBufferGetBytes = 2000 + MBufferGetByteSlice = 2000 + MBufferCopyByteSlice = 2000 + MBufferSetBytes = 2000 + MBufferAppend = 2000 + MBufferAppendBytes = 2000 + MBufferToBigIntUnsigned = 2000 + MBufferToBigIntSigned = 5000 + MBufferFromBigIntUnsigned = 2000 + MBufferFromBigIntSigned = 5000 + MBufferStorageStore = 75000 + MBufferStorageLoad = 50000 + MBufferGetArgument = 1000 + MBufferFinish = 1000 + MBufferSetRandom = 6000 + MBufferToBigFloat = 2000 + MBufferFromBigFloat = 2000 + MBufferToSmallIntUnsigned = 10000 + MBufferToSmallIntSigned = 10000 + MBufferFromSmallIntUnsigned = 10000 + MBufferFromSmallIntSigned = 10000 + +[BigFloatAPICost] + BigFloatNewFromParts = 3000 + BigFloatAdd = 7000 + BigFloatSub = 7000 + BigFloatMul = 7000 + BigFloatDiv = 7000 + BigFloatTruncate = 5000 + BigFloatNeg = 5000 + BigFloatClone = 5000 + BigFloatCmp = 4000 + BigFloatAbs = 5000 + BigFloatSqrt = 7000 + BigFloatPow = 10000 + BigFloatFloor = 5000 + BigFloatCeil = 5000 + BigFloatIsInt = 3000 + BigFloatSetBigInt = 3000 + BigFloatSetInt64 = 1000 + BigFloatGetConst = 1000 + +[ManagedMapAPICost] + ManagedMapNew = 10000 + ManagedMapPut = 10000 + ManagedMapGet = 10000 + ManagedMapRemove = 10000 + ManagedMapContains = 10000 + +[WASMOpcodeCost] + Unreachable = 5 + Nop = 5 + Block = 5 + Loop = 5 + If = 5 + Else = 5 + End = 5 + Br = 5 + BrIf = 5 + BrTable = 5 + Return = 5 + Call = 5 + CallIndirect = 5 + Drop = 5 + Select = 5 + TypedSelect = 5 + LocalGet = 5 + LocalSet = 5 + LocalTee = 5 + GlobalGet = 5 + GlobalSet = 5 + I32Load = 5 + I64Load = 5 + F32Load = 6 + F64Load = 6 + I32Load8S = 5 + I32Load8U = 5 + I32Load16S = 5 + I32Load16U = 5 + I64Load8S = 5 + I64Load8U = 5 + I64Load16S = 5 + I64Load16U = 5 + I64Load32S = 5 + I64Load32U = 5 + I32Store = 5 + I64Store = 5 + F32Store = 12 + F64Store = 12 + I32Store8 = 5 + I32Store16 = 5 + I64Store8 = 5 + I64Store16 = 5 + I64Store32 = 5 + MemorySize = 5 + MemoryGrow = 1000000 + I32Const = 5 + I64Const = 5 + F32Const = 5 + F64Const = 5 + RefNull = 5 + RefIsNull = 5 + RefFunc = 5 + I32Eqz = 5 + I32Eq = 5 + I32Ne = 5 + I32LtS = 5 + I32LtU = 5 + I32GtS = 5 + I32GtU = 5 + I32LeS = 5 + I32LeU = 5 + I32GeS = 5 + I32GeU = 5 + I64Eqz = 5 + I64Eq = 5 + 
I64Ne = 5 + I64LtS = 5 + I64LtU = 5 + I64GtS = 5 + I64GtU = 5 + I64LeS = 5 + I64LeU = 5 + I64GeS = 5 + I64GeU = 5 + F32Eq = 6 + F32Ne = 6 + F32Lt = 6 + F32Gt = 6 + F32Le = 6 + F32Ge = 6 + F64Eq = 6 + F64Ne = 6 + F64Lt = 6 + F64Gt = 6 + F64Le = 6 + F64Ge = 6 + I32Clz = 100 + I32Ctz = 100 + I32Popcnt = 100 + I32Add = 5 + I32Sub = 5 + I32Mul = 5 + I32DivS = 18 + I32DivU = 18 + I32RemS = 18 + I32RemU = 18 + I32And = 5 + I32Or = 5 + I32Xor = 5 + I32Shl = 5 + I32ShrS = 5 + I32ShrU = 5 + I32Rotl = 5 + I32Rotr = 5 + I64Clz = 100 + I64Ctz = 100 + I64Popcnt = 100 + I64Add = 5 + I64Sub = 5 + I64Mul = 5 + I64DivS = 18 + I64DivU = 18 + I64RemS = 18 + I64RemU = 18 + I64And = 5 + I64Or = 5 + I64Xor = 5 + I64Shl = 5 + I64ShrS = 5 + I64ShrU = 5 + I64Rotl = 5 + I64Rotr = 5 + F32Abs = 5 + F32Neg = 5 + F32Ceil = 100 + F32Floor = 100 + F32Trunc = 100 + F32Nearest = 100 + F32Sqrt = 100 + F32Add = 5 + F32Sub = 5 + F32Mul = 15 + F32Div = 100 + F32Min = 15 + F32Max = 15 + F32Copysign = 5 + F64Abs = 5 + F64Neg = 5 + F64Ceil = 100 + F64Floor = 100 + F64Trunc = 100 + F64Nearest = 100 + F64Sqrt = 100 + F64Add = 5 + F64Sub = 5 + F64Mul = 15 + F64Div = 100 + F64Min = 15 + F64Max = 15 + F64Copysign = 5 + I32WrapI64 = 9 + I32TruncF32S = 100 + I32TruncF32U = 100 + I32TruncF64S = 100 + I32TruncF64U = 100 + I64ExtendI32S = 9 + I64ExtendI32U = 9 + I64TruncF32S = 100 + I64TruncF32U = 100 + I64TruncF64S = 100 + I64TruncF64U = 100 + F32ConvertI32S = 100 + F32ConvertI32U = 100 + F32ConvertI64S = 100 + F32ConvertI64U = 100 + F32DemoteF64 = 100 + F64ConvertI32S = 100 + F64ConvertI32U = 100 + F64ConvertI64S = 100 + F64ConvertI64U = 100 + F64PromoteF32 = 100 + I32ReinterpretF32 = 100 + I64ReinterpretF64 = 100 + F32ReinterpretI32 = 100 + F64ReinterpretI64 = 100 + I32Extend8S = 9 + I32Extend16S = 9 + I64Extend8S = 9 + I64Extend16S = 9 + I64Extend32S = 9 + I32TruncSatF32S = 100 + I32TruncSatF32U = 100 + I32TruncSatF64S = 100 + I32TruncSatF64U = 100 + I64TruncSatF32S = 100 + I64TruncSatF32U = 100 + I64TruncSatF64S = 100 + I64TruncSatF64U = 100 + MemoryInit = 5 + DataDrop = 5 + MemoryCopy = 5 + MemoryFill = 5 + TableInit = 10 + ElemDrop = 10 + TableCopy = 10 + TableFill = 10 + TableGet = 10 + TableSet = 10 + TableGrow = 10 + TableSize = 10 + AtomicNotify = 1000000 + I32AtomicWait = 1000000 + I64AtomicWait = 1000000 + AtomicFence = 1000000 + I32AtomicLoad = 1000000 + I64AtomicLoad = 1000000 + I32AtomicLoad8U = 1000000 + I32AtomicLoad16U = 1000000 + I64AtomicLoad8U = 1000000 + I64AtomicLoad16U = 1000000 + I64AtomicLoad32U = 1000000 + I32AtomicStore = 1000000 + I64AtomicStore = 1000000 + I32AtomicStore8 = 1000000 + I32AtomicStore16 = 1000000 + I64AtomicStore8 = 1000000 + I64AtomicStore16 = 1000000 + I64AtomicStore32 = 1000000 + I32AtomicRmwAdd = 1000000 + I64AtomicRmwAdd = 1000000 + I32AtomicRmw8AddU = 1000000 + I32AtomicRmw16AddU = 1000000 + I64AtomicRmw8AddU = 1000000 + I64AtomicRmw16AddU = 1000000 + I64AtomicRmw32AddU = 1000000 + I32AtomicRmwSub = 1000000 + I64AtomicRmwSub = 1000000 + I32AtomicRmw8SubU = 1000000 + I32AtomicRmw16SubU = 1000000 + I64AtomicRmw8SubU = 1000000 + I64AtomicRmw16SubU = 1000000 + I64AtomicRmw32SubU = 1000000 + I32AtomicRmwAnd = 1000000 + I64AtomicRmwAnd = 1000000 + I32AtomicRmw8AndU = 1000000 + I32AtomicRmw16AndU = 1000000 + I64AtomicRmw8AndU = 1000000 + I64AtomicRmw16AndU = 1000000 + I64AtomicRmw32AndU = 1000000 + I32AtomicRmwOr = 1000000 + I64AtomicRmwOr = 1000000 + I32AtomicRmw8OrU = 1000000 + I32AtomicRmw16OrU = 1000000 + I64AtomicRmw8OrU = 1000000 + I64AtomicRmw16OrU = 1000000 + I64AtomicRmw32OrU = 1000000 + 
I32AtomicRmwXor = 1000000 + I64AtomicRmwXor = 1000000 + I32AtomicRmw8XorU = 1000000 + I32AtomicRmw16XorU = 1000000 + I64AtomicRmw8XorU = 1000000 + I64AtomicRmw16XorU = 1000000 + I64AtomicRmw32XorU = 1000000 + I32AtomicRmwXchg = 1000000 + I64AtomicRmwXchg = 1000000 + I32AtomicRmw8XchgU = 1000000 + I32AtomicRmw16XchgU = 1000000 + I64AtomicRmw8XchgU = 1000000 + I64AtomicRmw16XchgU = 1000000 + I64AtomicRmw32XchgU = 1000000 + I32AtomicRmwCmpxchg = 1000000 + I64AtomicRmwCmpxchg = 1000000 + I32AtomicRmw8CmpxchgU = 1000000 + I32AtomicRmw16CmpxchgU = 1000000 + I64AtomicRmw8CmpxchgU = 1000000 + I64AtomicRmw16CmpxchgU = 1000000 + I64AtomicRmw32CmpxchgU = 1000000 + V128Load = 1000000 + V128Store = 1000000 + V128Const = 1000000 + I8x16Splat = 1000000 + I8x16ExtractLaneS = 1000000 + I8x16ExtractLaneU = 1000000 + I8x16ReplaceLane = 1000000 + I16x8Splat = 1000000 + I16x8ExtractLaneS = 1000000 + I16x8ExtractLaneU = 1000000 + I16x8ReplaceLane = 1000000 + I32x4Splat = 1000000 + I32x4ExtractLane = 1000000 + I32x4ReplaceLane = 1000000 + I64x2Splat = 1000000 + I64x2ExtractLane = 1000000 + I64x2ReplaceLane = 1000000 + F32x4Splat = 1000000 + F32x4ExtractLane = 1000000 + F32x4ReplaceLane = 1000000 + F64x2Splat = 1000000 + F64x2ExtractLane = 1000000 + F64x2ReplaceLane = 1000000 + I8x16Eq = 1000000 + I8x16Ne = 1000000 + I8x16LtS = 1000000 + I8x16LtU = 1000000 + I8x16GtS = 1000000 + I8x16GtU = 1000000 + I8x16LeS = 1000000 + I8x16LeU = 1000000 + I8x16GeS = 1000000 + I8x16GeU = 1000000 + I16x8Eq = 1000000 + I16x8Ne = 1000000 + I16x8LtS = 1000000 + I16x8LtU = 1000000 + I16x8GtS = 1000000 + I16x8GtU = 1000000 + I16x8LeS = 1000000 + I16x8LeU = 1000000 + I16x8GeS = 1000000 + I16x8GeU = 1000000 + I32x4Eq = 1000000 + I32x4Ne = 1000000 + I32x4LtS = 1000000 + I32x4LtU = 1000000 + I32x4GtS = 1000000 + I32x4GtU = 1000000 + I32x4LeS = 1000000 + I32x4LeU = 1000000 + I32x4GeS = 1000000 + I32x4GeU = 1000000 + F32x4Eq = 1000000 + F32x4Ne = 1000000 + F32x4Lt = 1000000 + F32x4Gt = 1000000 + F32x4Le = 1000000 + F32x4Ge = 1000000 + F64x2Eq = 1000000 + F64x2Ne = 1000000 + F64x2Lt = 1000000 + F64x2Gt = 1000000 + F64x2Le = 1000000 + F64x2Ge = 1000000 + V128Not = 1000000 + V128And = 1000000 + V128AndNot = 1000000 + V128Or = 1000000 + V128Xor = 1000000 + V128Bitselect = 1000000 + I8x16Neg = 1000000 + I8x16AnyTrue = 1000000 + I8x16AllTrue = 1000000 + I8x16Shl = 1000000 + I8x16ShrS = 1000000 + I8x16ShrU = 1000000 + I8x16Add = 1000000 + I8x16AddSaturateS = 1000000 + I8x16AddSaturateU = 1000000 + I8x16Sub = 1000000 + I8x16SubSaturateS = 1000000 + I8x16SubSaturateU = 1000000 + I8x16MinS = 1000000 + I8x16MinU = 1000000 + I8x16MaxS = 1000000 + I8x16MaxU = 1000000 + I8x16Mul = 1000000 + I16x8Neg = 1000000 + I16x8AnyTrue = 1000000 + I16x8AllTrue = 1000000 + I16x8Shl = 1000000 + I16x8ShrS = 1000000 + I16x8ShrU = 1000000 + I16x8Add = 1000000 + I16x8AddSaturateS = 1000000 + I16x8AddSaturateU = 1000000 + I16x8Sub = 1000000 + I16x8SubSaturateS = 1000000 + I16x8SubSaturateU = 1000000 + I16x8Mul = 1000000 + I16x8MinS = 1000000 + I16x8MinU = 1000000 + I16x8MaxS = 1000000 + I16x8MaxU = 1000000 + I32x4Neg = 1000000 + I32x4AnyTrue = 1000000 + I32x4AllTrue = 1000000 + I32x4Shl = 1000000 + I32x4ShrS = 1000000 + I32x4ShrU = 1000000 + I32x4Add = 1000000 + I32x4Sub = 1000000 + I32x4Mul = 1000000 + I32x4MinS = 1000000 + I32x4MinU = 1000000 + I32x4MaxS = 1000000 + I32x4MaxU = 1000000 + I64x2Neg = 1000000 + I64x2AnyTrue = 1000000 + I64x2AllTrue = 1000000 + I64x2Shl = 1000000 + I64x2ShrS = 1000000 + I64x2ShrU = 1000000 + I64x2Add = 1000000 + I64x2Sub = 1000000 + I64x2Mul 
= 1000000 + F32x4Abs = 1000000 + F32x4Neg = 1000000 + F32x4Sqrt = 1000000 + F32x4Add = 1000000 + F32x4Sub = 1000000 + F32x4Mul = 1000000 + F32x4Div = 1000000 + F32x4Min = 1000000 + F32x4Max = 1000000 + F64x2Abs = 1000000 + F64x2Neg = 1000000 + F64x2Sqrt = 1000000 + F64x2Add = 1000000 + F64x2Sub = 1000000 + F64x2Mul = 1000000 + F64x2Div = 1000000 + F64x2Min = 1000000 + F64x2Max = 1000000 + I32x4TruncSatF32x4S = 1000000 + I32x4TruncSatF32x4U = 1000000 + I64x2TruncSatF64x2S = 1000000 + I64x2TruncSatF64x2U = 1000000 + F32x4ConvertI32x4S = 1000000 + F32x4ConvertI32x4U = 1000000 + F64x2ConvertI64x2S = 1000000 + F64x2ConvertI64x2U = 1000000 + V8x16Swizzle = 1000000 + V8x16Shuffle = 1000000 + V8x16LoadSplat = 1000000 + V16x8LoadSplat = 1000000 + V32x4LoadSplat = 1000000 + V64x2LoadSplat = 1000000 + I8x16NarrowI16x8S = 1000000 + I8x16NarrowI16x8U = 1000000 + I16x8NarrowI32x4S = 1000000 + I16x8NarrowI32x4U = 1000000 + I16x8WidenLowI8x16S = 1000000 + I16x8WidenHighI8x16S = 1000000 + I16x8WidenLowI8x16U = 1000000 + I16x8WidenHighI8x16U = 1000000 + I32x4WidenLowI16x8S = 1000000 + I32x4WidenHighI16x8S = 1000000 + I32x4WidenLowI16x8U = 1000000 + I32x4WidenHighI16x8U = 1000000 + I16x8Load8x8S = 1000000 + I16x8Load8x8U = 1000000 + I32x4Load16x4S = 1000000 + I32x4Load16x4U = 1000000 + I64x2Load32x2S = 1000000 + I64x2Load32x2U = 1000000 + I8x16RoundingAverageU = 1000000 + I16x8RoundingAverageU = 1000000 + LocalAllocate = 5 + LocalsUnmetered = 100 + MaxMemoryGrowDelta = 1 + MaxMemoryGrow = 10 + Catch = 10 + CatchAll = 10 + Delegate = 10 + Rethrow = 10 + ReturnCall = 10 + ReturnCallIndirect = 10 + Throw = 10 + Try = 10 + Unwind = 10 + F32x4Ceil = 1000000 + F32x4DemoteF64x2Zero = 1000000 + F32x4Floor = 1000000 + F32x4Nearest = 1000000 + F32x4PMax = 1000000 + F32x4PMin = 1000000 + F32x4Trunc = 1000000 + F64x2Ceil = 1000000 + F64x2ConvertLowI32x4S = 1000000 + F64x2ConvertLowI32x4U = 1000000 + F64x2Floor = 1000000 + F64x2Nearest = 1000000 + F64x2PMax = 1000000 + F64x2PMin = 1000000 + F64x2PromoteLowF32x4 = 1000000 + F64x2Trunc = 1000000 + I16x8Abs = 1000000 + I16x8AddSatS = 1000000 + I16x8AddSatU = 1000000 + I16x8Bitmask = 1000000 + I16x8ExtAddPairwiseI8x16S = 1000000 + I16x8ExtAddPairwiseI8x16U = 1000000 + I16x8ExtMulHighI8x16S = 1000000 + I16x8ExtMulHighI8x16U = 1000000 + I16x8ExtMulLowI8x16S = 1000000 + I16x8ExtMulLowI8x16U = 1000000 + I16x8ExtendHighI8x16S = 1000000 + I16x8ExtendHighI8x16U = 1000000 + I16x8ExtendLowI8x16S = 1000000 + I16x8ExtendLowI8x16U = 1000000 + I16x8Q15MulrSatS = 1000000 + I16x8SubSatS = 1000000 + I16x8SubSatU = 1000000 + I32x4Abs = 1000000 + I32x4Bitmask = 1000000 + I32x4DotI16x8S = 1000000 + I32x4ExtAddPairwiseI16x8S = 1000000 + I32x4ExtAddPairwiseI16x8U = 1000000 + I32x4ExtMulHighI16x8S = 1000000 + I32x4ExtMulHighI16x8U = 1000000 + I32x4ExtMulLowI16x8S = 1000000 + I32x4ExtMulLowI16x8U = 1000000 + I32x4ExtendHighI16x8S = 1000000 + I32x4ExtendHighI16x8U = 1000000 + I32x4ExtendLowI16x8S = 1000000 + I32x4ExtendLowI16x8U = 1000000 + I32x4TruncSatF64x2SZero = 1000000 + I32x4TruncSatF64x2UZero = 1000000 + I64x2Abs = 1000000 + I64x2Bitmask = 1000000 + I64x2Eq = 1000000 + I64x2ExtMulHighI32x4S = 1000000 + I64x2ExtMulHighI32x4U = 1000000 + I64x2ExtMulLowI32x4S = 1000000 + I64x2ExtMulLowI32x4U = 1000000 + I64x2ExtendHighI32x4S = 1000000 + I64x2ExtendHighI32x4U = 1000000 + I64x2ExtendLowI32x4S = 1000000 + I64x2ExtendLowI32x4U = 1000000 + I64x2GeS = 1000000 + I64x2GtS = 1000000 + I64x2LeS = 1000000 + I64x2LtS = 1000000 + I64x2Ne = 1000000 + I8x16Abs = 1000000 + I8x16AddSatS = 1000000 + I8x16AddSatU 
= 1000000 + I8x16Bitmask = 1000000 + I8x16Popcnt = 1000000 + I8x16Shuffle = 1000000 + I8x16SubSatS = 1000000 + I8x16SubSatU = 1000000 + I8x16Swizzle = 1000000 + MemoryAtomicNotify = 1000000 + MemoryAtomicWait32 = 1000000 + MemoryAtomicWait64 = 1000000 + V128AnyTrue = 1000000 + V128Load16Lane = 1000000 + V128Load16Splat = 1000000 + V128Load16x4S = 1000000 + V128Load16x4U = 1000000 + V128Load32Lane = 1000000 + V128Load32Splat = 1000000 + V128Load32Zero = 1000000 + V128Load32x2S = 1000000 + V128Load32x2U = 1000000 + V128Load64Lane = 1000000 + V128Load64Splat = 1000000 + V128Load64Zero = 1000000 + V128Load8Lane = 1000000 + V128Load8Splat = 1000000 + V128Load8x8S = 1000000 + V128Load8x8U = 1000000 + V128Store16Lane = 1000000 + V128Store32Lane = 1000000 + V128Store64Lane = 1000000 + V128Store8Lane = 1000000 + +[MaxPerTransaction] + MaxBuiltInCallsPerTx = 100 + MaxNumberOfTransfersPerTx = 250 + MaxNumberOfTrieReadsPerTx = 1500 + +# Quadratic, Linear and Constant are the coefficients for a quadratic func. Separate variables are used for the +# sign of each coefficient, 0 meaning positive and 1 meaning negative +# The current values for the coefficients were computed based on benchmarking. +# For the given coefficients, the minimum of the function must not be lower than MinimumGasCost +[DynamicStorageLoad] + QuadraticCoefficient = 688 + SignOfQuadratic = 0 + LinearCoefficient = 31858 + SignOfLinear = 0 + ConstantCoefficient = 15287 + SignOfConstant = 0 + MinimumGasCost = 10000 diff --git a/cmd/node/config/nodesSetup.json b/cmd/node/config/nodesSetup.json index 741d9009ad8..daa5fd1b98a 100644 --- a/cmd/node/config/nodesSetup.json +++ b/cmd/node/config/nodesSetup.json @@ -1,12 +1,5 @@ { "startTime": 0, - "roundDuration": 6000, - "consensusGroupSize": 7, - "minNodesPerShard": 10, - "metaChainConsensusGroupSize": 10, - "metaChainMinNodes": 10, - "hysteresis": 0.2, - "adaptivity": false, "initialNodes": [ { "info": "multikey - group1 - legacy delegation", diff --git a/cmd/node/config/ratings.toml b/cmd/node/config/ratings.toml index 13edfff932a..c9cdf1c4be4 100644 --- a/cmd/node/config/ratings.toml +++ b/cmd/node/config/ratings.toml @@ -28,19 +28,30 @@ { MaxThreshold = 10000000,ChancePercent = 24}, ] -[ShardChain.RatingSteps] - HoursToMaxRatingFromStartRating = 72 - ProposerValidatorImportance = 1.0 - ProposerDecreaseFactor = -4.0 - ValidatorDecreaseFactor = -4.0 - ConsecutiveMissedBlocksPenalty = 1.50 +[ShardChain] + [[ShardChain.RatingStepsByEpoch]] + EnableEpoch = 0 + HoursToMaxRatingFromStartRating = 72 + ProposerValidatorImportance = 1.0 + ProposerDecreaseFactor = -4.0 + ValidatorDecreaseFactor = -4.0 + ConsecutiveMissedBlocksPenalty = 1.50 + [[ShardChain.RatingStepsByEpoch]] + EnableEpoch = 1 + HoursToMaxRatingFromStartRating = 55 + ProposerValidatorImportance = 1.0 + ProposerDecreaseFactor = -4.0 + ValidatorDecreaseFactor = -4.0 + ConsecutiveMissedBlocksPenalty = 1.50 -[MetaChain.RatingSteps] - HoursToMaxRatingFromStartRating = 55 - ProposerValidatorImportance = 1.0 - ProposerDecreaseFactor = -4.0 - ValidatorDecreaseFactor = -4.0 - ConsecutiveMissedBlocksPenalty = 1.50 +[MetaChain] + [[MetaChain.RatingStepsByEpoch]] + EnableEpoch = 0 + HoursToMaxRatingFromStartRating = 55 + ProposerValidatorImportance = 1.0 + ProposerDecreaseFactor = -4.0 + ValidatorDecreaseFactor = -4.0 + ConsecutiveMissedBlocksPenalty = 1.50 [PeerHonesty] #this value will be multiplied with the current value for a public key each DecayUpdateIntervalInSeconds seconds diff --git a/cmd/node/factory/interface.go 
b/cmd/node/factory/interface.go index 21c74696087..8f90ce3ee89 100644 --- a/cmd/node/factory/interface.go +++ b/cmd/node/factory/interface.go @@ -5,6 +5,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/p2p" ) @@ -14,6 +15,8 @@ type HeaderSigVerifierHandler interface { VerifyLeaderSignature(header data.HeaderHandler) error VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error VerifySignature(header data.HeaderHandler) error + VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error + VerifyHeaderProof(headerProof data.HeaderProofHandler) error IsInterfaceNil() bool } diff --git a/cmd/node/main.go b/cmd/node/main.go index 5a812bc2f45..c75dd40a393 100644 --- a/cmd/node/main.go +++ b/cmd/node/main.go @@ -238,6 +238,14 @@ func readConfigs(ctx *cli.Context, log logger.Logger) (*config.Configs, error) { } log.Debug("config", "file", configurationPaths.RoundActivation) + var nodesSetup config.NodesConfig + configurationPaths.Nodes = ctx.GlobalString(nodesFile.Name) + err = core.LoadJsonFile(&nodesSetup, configurationPaths.Nodes) + if err != nil { + return nil, err + } + log.Debug("config", "file", configurationPaths.Nodes) + if ctx.IsSet(port.Name) { mainP2PConfig.Node.Port = ctx.GlobalString(port.Name) } @@ -267,6 +275,7 @@ func readConfigs(ctx *cli.Context, log logger.Logger) (*config.Configs, error) { ConfigurationPathsHolder: configurationPaths, EpochConfig: epochConfig, RoundConfig: roundConfig, + NodesConfig: &nodesSetup, }, nil } diff --git a/common/chainparametersnotifier/chainParametersNotifier.go b/common/chainparametersnotifier/chainParametersNotifier.go new file mode 100644 index 00000000000..1a3baf2b5ff --- /dev/null +++ b/common/chainparametersnotifier/chainParametersNotifier.go @@ -0,0 +1,97 @@ +package chainparametersnotifier + +import ( + "sync" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("common/chainparameters") + +type chainParametersNotifier struct { + mutData sync.RWMutex + wasInitialized bool + currentChainParameters config.ChainParametersByEpochConfig + mutHandler sync.RWMutex + handlers []common.ChainParametersSubscriptionHandler +} + +// NewChainParametersNotifier creates a new instance of a chainParametersNotifier component +func NewChainParametersNotifier() *chainParametersNotifier { + return &chainParametersNotifier{ + wasInitialized: false, + handlers: make([]common.ChainParametersSubscriptionHandler, 0), + } +} + +// UpdateCurrentChainParameters should be called whenever new chain parameters become active on the network +func (cpn *chainParametersNotifier) UpdateCurrentChainParameters(params config.ChainParametersByEpochConfig) { + cpn.mutData.Lock() + shouldSkipParams := cpn.wasInitialized && cpn.currentChainParameters.EnableEpoch == params.EnableEpoch + if shouldSkipParams { + cpn.mutData.Unlock() + + return + } + cpn.wasInitialized = true + cpn.currentChainParameters = params + cpn.mutData.Unlock() + + cpn.mutHandler.RLock() + handlersCopy := make([]common.ChainParametersSubscriptionHandler, len(cpn.handlers)) + copy(handlersCopy, cpn.handlers) + cpn.mutHandler.RUnlock() + + 
log.Debug("chainParametersNotifier.UpdateCurrentChainParameters", + "enable epoch", params.EnableEpoch, + "shard consensus group size", params.ShardConsensusGroupSize, + "shard min number of nodes", params.ShardMinNumNodes, + "meta consensus group size", params.MetachainConsensusGroupSize, + "meta min number of nodes", params.MetachainMinNumNodes, + "round duration", params.RoundDuration, + "hysteresis", params.Hysteresis, + "adaptivity", params.Adaptivity, + ) + + for _, handler := range handlersCopy { + handler.ChainParametersChanged(params) + } +} + +// RegisterNotifyHandler will register the provided handler to be called whenever chain parameters have changed +func (cpn *chainParametersNotifier) RegisterNotifyHandler(handler common.ChainParametersSubscriptionHandler) { + if check.IfNil(handler) { + return + } + + cpn.mutHandler.Lock() + cpn.handlers = append(cpn.handlers, handler) + cpn.mutHandler.Unlock() + + cpn.mutData.RLock() + handler.ChainParametersChanged(cpn.currentChainParameters) + cpn.mutData.RUnlock() +} + +// CurrentChainParameters returns the current chain parameters +func (cpn *chainParametersNotifier) CurrentChainParameters() config.ChainParametersByEpochConfig { + cpn.mutData.RLock() + defer cpn.mutData.RUnlock() + + return cpn.currentChainParameters +} + +// UnRegisterAll removes all registered handlers queue +func (cpn *chainParametersNotifier) UnRegisterAll() { + cpn.mutHandler.Lock() + cpn.handlers = make([]common.ChainParametersSubscriptionHandler, 0) + cpn.mutHandler.Unlock() +} + +// IsInterfaceNil returns true if there is no value under the interface +func (cpn *chainParametersNotifier) IsInterfaceNil() bool { + return cpn == nil +} diff --git a/common/chainparametersnotifier/chainParametersNotifier_test.go b/common/chainparametersnotifier/chainParametersNotifier_test.go new file mode 100644 index 00000000000..fa1a30959d4 --- /dev/null +++ b/common/chainparametersnotifier/chainParametersNotifier_test.go @@ -0,0 +1,126 @@ +package chainparametersnotifier + +import ( + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/config" + "github.com/stretchr/testify/require" +) + +func TestNewChainParametersNotifier(t *testing.T) { + t.Parallel() + + notifier := NewChainParametersNotifier() + require.False(t, check.IfNil(notifier)) +} + +func TestChainParametersNotifier_UpdateCurrentChainParameters(t *testing.T) { + t.Parallel() + + notifier := NewChainParametersNotifier() + require.False(t, check.IfNil(notifier)) + + chainParams := config.ChainParametersByEpochConfig{ + EnableEpoch: 7, + Adaptivity: true, + Hysteresis: 0.7, + } + notifier.UpdateCurrentChainParameters(chainParams) + + resultedChainParams := notifier.CurrentChainParameters() + require.NotNil(t, resultedChainParams) + + // update with same epoch but other params - should not change (impossible scenario in production, but easier for tests) + chainParams.Hysteresis = 0.8 + notifier.UpdateCurrentChainParameters(chainParams) + require.Equal(t, float32(0.7), notifier.CurrentChainParameters().Hysteresis) + + chainParams.Hysteresis = 0.8 + chainParams.EnableEpoch = 8 + notifier.UpdateCurrentChainParameters(chainParams) + require.Equal(t, float32(0.8), notifier.CurrentChainParameters().Hysteresis) +} + +func TestChainParametersNotifier_RegisterNotifyHandler(t *testing.T) { + t.Parallel() + + notifier := NewChainParametersNotifier() + require.False(t, check.IfNil(notifier)) + + // register a nil handler - 
should not panic + notifier.RegisterNotifyHandler(nil) + + testNotifee := &dummyNotifee{} + notifier.RegisterNotifyHandler(testNotifee) + + chainParams := config.ChainParametersByEpochConfig{ + ShardMinNumNodes: 37, + } + notifier.UpdateCurrentChainParameters(chainParams) + + require.Equal(t, chainParams, testNotifee.receivedChainParameters) +} + +func TestChainParametersNotifier_UnRegisterAll(t *testing.T) { + t.Parallel() + + notifier := NewChainParametersNotifier() + require.False(t, check.IfNil(notifier)) + + testNotifee := &dummyNotifee{} + notifier.RegisterNotifyHandler(testNotifee) + notifier.UnRegisterAll() + + chainParams := config.ChainParametersByEpochConfig{ + ShardMinNumNodes: 37, + } + notifier.UpdateCurrentChainParameters(chainParams) + + require.Empty(t, testNotifee.receivedChainParameters) +} + +func TestChainParametersNotifier_ConcurrentOperations(t *testing.T) { + t.Parallel() + + notifier := NewChainParametersNotifier() + + numOperations := 500 + wg := sync.WaitGroup{} + wg.Add(numOperations) + for i := 0; i < numOperations; i++ { + go func(idx int) { + switch idx { + case 0: + notifier.RegisterNotifyHandler(&dummyNotifee{}) + case 1: + _ = notifier.CurrentChainParameters() + case 2: + notifier.UpdateCurrentChainParameters(config.ChainParametersByEpochConfig{}) + case 3: + notifier.UnRegisterAll() + case 4: + _ = notifier.IsInterfaceNil() + } + + wg.Done() + }(i % 5) + } + + wg.Wait() +} + +type dummyNotifee struct { + receivedChainParameters config.ChainParametersByEpochConfig +} + +// ChainParametersChanged - +func (dn *dummyNotifee) ChainParametersChanged(chainParameters config.ChainParametersByEpochConfig) { + dn.receivedChainParameters = chainParameters +} + +// IsInterfaceNil - +func (dn *dummyNotifee) IsInterfaceNil() bool { + return dn == nil +} diff --git a/common/common.go b/common/common.go index c1e565043ad..98c6bb52f58 100644 --- a/common/common.go +++ b/common/common.go @@ -1,6 +1,31 @@ package common -import "github.com/multiversx/mx-chain-core-go/data" +import ( + "encoding/hex" + "fmt" + "math/bits" + "strconv" + "strings" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/config" + logger "github.com/multiversx/mx-chain-logger-go" +) + +const ( + keySeparator = "-" + expectedKeyLen = 2 + hashIndex = 0 + shardIndex = 1 + nonceIndex = 0 +) + +type chainParametersHandler interface { + CurrentChainParameters() config.ChainParametersByEpochConfig + ChainParametersForEpoch(epoch uint32) (config.ChainParametersByEpochConfig, error) + IsInterfaceNil() bool +} // IsValidRelayedTxV3 returns true if the provided transaction is a valid transaction of type relayed v3 func IsValidRelayedTxV3(tx data.TransactionHandler) bool { @@ -24,3 +49,156 @@ func IsRelayedTxV3(tx data.TransactionHandler) bool { hasRelayerSignature := len(relayedTx.GetRelayerSignature()) > 0 return hasRelayer || hasRelayerSignature } + +// IsEpochChangeBlockForFlagActivation returns true if the provided header is the first one after the specified flag's activation +func IsEpochChangeBlockForFlagActivation(header data.HeaderHandler, enableEpochsHandler EnableEpochsHandler, flag core.EnableEpochFlag) bool { + isStartOfEpochBlock := header.IsStartOfEpochBlock() + isBlockInActivationEpoch := header.GetEpoch() == enableEpochsHandler.GetActivationEpoch(flag) + + return isStartOfEpochBlock && isBlockInActivationEpoch +} + +// 
IsFlagEnabledAfterEpochsStartBlock returns true if the flag is enabled for the header, but it is not the epoch start block +func IsFlagEnabledAfterEpochsStartBlock(header data.HeaderHandler, enableEpochsHandler EnableEpochsHandler, flag core.EnableEpochFlag) bool { + isFlagEnabled := enableEpochsHandler.IsFlagEnabledInEpoch(flag, header.GetEpoch()) + isEpochStartBlock := IsEpochChangeBlockForFlagActivation(header, enableEpochsHandler, flag) + return isFlagEnabled && !isEpochStartBlock +} + +// GetShardIDs returns a map of shard IDs based on the number of shards +func GetShardIDs(numShards uint32) map[uint32]struct{} { + shardIdentifiers := make(map[uint32]struct{}) + for i := uint32(0); i < numShards; i++ { + shardIdentifiers[i] = struct{}{} + } + shardIdentifiers[core.MetachainShardId] = struct{}{} + + return shardIdentifiers +} + +// GetBitmapSize will return expected bitmap size based on provided consensus size +func GetBitmapSize( + consensusSize int, +) int { + expectedBitmapSize := consensusSize / 8 + if consensusSize%8 != 0 { + expectedBitmapSize++ + } + + return expectedBitmapSize +} + +// IsConsensusBitmapValid checks if the provided keys and bitmap match the consensus requirements +func IsConsensusBitmapValid( + log logger.Logger, + consensusPubKeys []string, + bitmap []byte, + shouldApplyFallbackValidation bool, +) error { + consensusSize := len(consensusPubKeys) + + expectedBitmapSize := GetBitmapSize(consensusSize) + if len(bitmap) != expectedBitmapSize { + log.Debug("wrong size bitmap", + "expected number of bytes", expectedBitmapSize, + "actual", len(bitmap)) + return ErrWrongSizeBitmap + } + + numOfOnesInBitmap := 0 + for index := range bitmap { + numOfOnesInBitmap += bits.OnesCount8(bitmap[index]) + } + + minNumRequiredSignatures := core.GetPBFTThreshold(consensusSize) + if shouldApplyFallbackValidation { + minNumRequiredSignatures = core.GetPBFTFallbackThreshold(consensusSize) + log.Warn("IsConsensusBitmapValid: fallback validation has been applied", + "minimum number of signatures required", minNumRequiredSignatures, + "actual number of signatures in bitmap", numOfOnesInBitmap, + ) + } + + if numOfOnesInBitmap >= minNumRequiredSignatures { + return nil + } + + log.Debug("not enough signatures", + "minimum expected", minNumRequiredSignatures, + "actual", numOfOnesInBitmap) + + return ErrNotEnoughSignatures +} + +// ConsensusGroupSizeForShardAndEpoch returns the consensus group size for a specific shard in a given epoch +func ConsensusGroupSizeForShardAndEpoch( + log logger.Logger, + chainParametersHandler chainParametersHandler, + shardID uint32, + epoch uint32, +) int { + currentChainParameters, err := chainParametersHandler.ChainParametersForEpoch(epoch) + if err != nil { + log.Warn("ConsensusGroupSizeForShardAndEpoch: could not compute chain params for epoch. 
"+ + "Will use the current chain parameters", "epoch", epoch, "error", err) + currentChainParameters = chainParametersHandler.CurrentChainParameters() + } + + if shardID == core.MetachainShardId { + return int(currentChainParameters.MetachainConsensusGroupSize) + } + + return int(currentChainParameters.ShardConsensusGroupSize) +} + +// GetEquivalentProofNonceShardKey returns a string key nonce-shardID +func GetEquivalentProofNonceShardKey(nonce uint64, shardID uint32) string { + return fmt.Sprintf("%d%s%d", nonce, keySeparator, shardID) +} + +// GetEquivalentProofHashShardKey returns a string key hash-shardID +func GetEquivalentProofHashShardKey(hash []byte, shardID uint32) string { + return fmt.Sprintf("%s%s%d", hex.EncodeToString(hash), keySeparator, shardID) +} + +// GetHashAndShardFromKey returns the hash and shard from the provided key +func GetHashAndShardFromKey(hashShardKey []byte) ([]byte, uint32, error) { + hashShardKeyStr := string(hashShardKey) + result := strings.Split(hashShardKeyStr, keySeparator) + if len(result) != expectedKeyLen { + return nil, 0, ErrInvalidHashShardKey + } + + hash, err := hex.DecodeString(result[hashIndex]) + if err != nil { + return nil, 0, err + } + + shard, err := strconv.Atoi(result[shardIndex]) + if err != nil { + return nil, 0, err + } + + return hash, uint32(shard), nil +} + +// GetNonceAndShardFromKey returns the nonce and shard from the provided key +func GetNonceAndShardFromKey(nonceShardKey []byte) (uint64, uint32, error) { + nonceShardKeyStr := string(nonceShardKey) + result := strings.Split(nonceShardKeyStr, keySeparator) + if len(result) != expectedKeyLen { + return 0, 0, ErrInvalidNonceShardKey + } + + nonce, err := strconv.Atoi(result[nonceIndex]) + if err != nil { + return 0, 0, err + } + + shard, err := strconv.Atoi(result[shardIndex]) + if err != nil { + return 0, 0, err + } + + return uint64(nonce), uint32(shard), nil +} diff --git a/common/common_test.go b/common/common_test.go index 5a0ec53a21f..22dd024821e 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -1,20 +1,30 @@ -package common +package common_test import ( + "errors" "math/big" "testing" + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/smartContractResult" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/stretchr/testify/require" ) +var testFlag = core.EnableEpochFlag("test flag") + func TestIsValidRelayedTxV3(t *testing.T) { t.Parallel() scr := &smartContractResult.SmartContractResult{} - require.False(t, IsValidRelayedTxV3(scr)) - require.False(t, IsRelayedTxV3(scr)) + require.False(t, common.IsValidRelayedTxV3(scr)) + require.False(t, common.IsRelayedTxV3(scr)) notRelayedTxV3 := &transaction.Transaction{ Nonce: 1, @@ -25,8 +35,8 @@ func TestIsValidRelayedTxV3(t *testing.T) { GasLimit: 10, Signature: []byte("signature"), } - require.False(t, IsValidRelayedTxV3(notRelayedTxV3)) - require.False(t, IsRelayedTxV3(notRelayedTxV3)) + require.False(t, common.IsValidRelayedTxV3(notRelayedTxV3)) + 
require.False(t, common.IsRelayedTxV3(notRelayedTxV3)) invalidRelayedTxV3 := &transaction.Transaction{ Nonce: 1, @@ -38,8 +48,8 @@ func TestIsValidRelayedTxV3(t *testing.T) { Signature: []byte("signature"), RelayerAddr: []byte("relayer"), } - require.False(t, IsValidRelayedTxV3(invalidRelayedTxV3)) - require.True(t, IsRelayedTxV3(invalidRelayedTxV3)) + require.False(t, common.IsValidRelayedTxV3(invalidRelayedTxV3)) + require.True(t, common.IsRelayedTxV3(invalidRelayedTxV3)) invalidRelayedTxV3 = &transaction.Transaction{ Nonce: 1, @@ -51,8 +61,8 @@ func TestIsValidRelayedTxV3(t *testing.T) { Signature: []byte("signature"), RelayerSignature: []byte("signature"), } - require.False(t, IsValidRelayedTxV3(invalidRelayedTxV3)) - require.True(t, IsRelayedTxV3(invalidRelayedTxV3)) + require.False(t, common.IsValidRelayedTxV3(invalidRelayedTxV3)) + require.True(t, common.IsRelayedTxV3(invalidRelayedTxV3)) relayedTxV3 := &transaction.Transaction{ Nonce: 1, @@ -65,6 +75,187 @@ func TestIsValidRelayedTxV3(t *testing.T) { RelayerAddr: []byte("relayer"), RelayerSignature: []byte("signature"), } - require.True(t, IsValidRelayedTxV3(relayedTxV3)) - require.True(t, IsRelayedTxV3(relayedTxV3)) + require.True(t, common.IsValidRelayedTxV3(relayedTxV3)) + require.True(t, common.IsRelayedTxV3(relayedTxV3)) +} + +func TestIsConsensusBitmapValid(t *testing.T) { + t.Parallel() + + log := &testscommon.LoggerStub{} + + pubKeys := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"} + + t.Run("wrong size bitmap", func(t *testing.T) { + t.Parallel() + + bitmap := make([]byte, len(pubKeys)/8) + + err := common.IsConsensusBitmapValid(log, pubKeys, bitmap, false) + require.Equal(t, common.ErrWrongSizeBitmap, err) + }) + + t.Run("not enough signatures", func(t *testing.T) { + t.Parallel() + + bitmap := make([]byte, len(pubKeys)/8+1) + bitmap[0] = 0x07 + + err := common.IsConsensusBitmapValid(log, pubKeys, bitmap, false) + require.Equal(t, common.ErrNotEnoughSignatures, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + bitmap := make([]byte, len(pubKeys)/8+1) + bitmap[0] = 0x77 + bitmap[1] = 0x01 + + err := common.IsConsensusBitmapValid(log, pubKeys, bitmap, false) + require.Nil(t, err) + }) + + t.Run("should work with fallback validation", func(t *testing.T) { + t.Parallel() + + bitmap := make([]byte, len(pubKeys)/8+1) + bitmap[0] = 0x77 + bitmap[1] = 0x01 + + err := common.IsConsensusBitmapValid(log, pubKeys, bitmap, true) + require.Nil(t, err) + }) +} + +func TestIsEpochChangeBlockForFlagActivation(t *testing.T) { + t.Parallel() + + providedEpoch := uint32(123) + eeh := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + GetActivationEpochCalled: func(flag core.EnableEpochFlag) uint32 { + require.Equal(t, testFlag, flag) + return providedEpoch + }, + } + + epochStartHeaderSameEpoch := &block.HeaderV2{ + Header: &block.Header{ + EpochStartMetaHash: []byte("meta hash"), + Epoch: providedEpoch, + }, + } + notEpochStartHeaderSameEpoch := &block.HeaderV2{ + Header: &block.Header{ + Epoch: providedEpoch, + }, + } + epochStartHeaderOtherEpoch := &block.HeaderV2{ + Header: &block.Header{ + EpochStartMetaHash: []byte("meta hash"), + Epoch: providedEpoch + 1, + }, + } + notEpochStartHeaderOtherEpoch := &block.HeaderV2{ + Header: &block.Header{ + Epoch: providedEpoch + 1, + }, + } + + require.True(t, common.IsEpochChangeBlockForFlagActivation(epochStartHeaderSameEpoch, eeh, testFlag)) + require.False(t, common.IsEpochChangeBlockForFlagActivation(notEpochStartHeaderSameEpoch, eeh, testFlag)) + 
require.False(t, common.IsEpochChangeBlockForFlagActivation(epochStartHeaderOtherEpoch, eeh, testFlag))
+	require.False(t, common.IsEpochChangeBlockForFlagActivation(notEpochStartHeaderOtherEpoch, eeh, testFlag))
+}
+
+func TestGetShardIDs(t *testing.T) {
+	t.Parallel()
+
+	shardIDs := common.GetShardIDs(2)
+	require.Equal(t, 3, len(shardIDs))
+	_, hasShard0 := shardIDs[0]
+	require.True(t, hasShard0)
+	_, hasShard1 := shardIDs[1]
+	require.True(t, hasShard1)
+	_, hasShardM := shardIDs[core.MetachainShardId]
+	require.True(t, hasShardM)
+}
+
+func TestGetBitmapSize(t *testing.T) {
+	t.Parallel()
+
+	require.Equal(t, 1, common.GetBitmapSize(8))
+	require.Equal(t, 2, common.GetBitmapSize(8+1))
+	require.Equal(t, 2, common.GetBitmapSize(8*2-1))
+	require.Equal(t, 50, common.GetBitmapSize(8*50)) // 400 consensus size
+}
+
+func TestConsensusGroupSizeForShardAndEpoch(t *testing.T) {
+	t.Parallel()
+
+	t.Run("shard node", func(t *testing.T) {
+		t.Parallel()
+
+		groupSize := uint32(400)
+
+		size := common.ConsensusGroupSizeForShardAndEpoch(
+			&testscommon.LoggerStub{},
+			&chainParameters.ChainParametersHandlerStub{
+				ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) {
+					return config.ChainParametersByEpochConfig{
+						ShardConsensusGroupSize: groupSize,
+					}, nil
+				},
+			},
+			1,
+			2,
+		)
+
+		require.Equal(t, int(groupSize), size)
+	})
+
+	t.Run("meta node", func(t *testing.T) {
+		t.Parallel()
+
+		groupSize := uint32(400)
+
+		size := common.ConsensusGroupSizeForShardAndEpoch(
+			&testscommon.LoggerStub{},
+			&chainParameters.ChainParametersHandlerStub{
+				ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) {
+					return config.ChainParametersByEpochConfig{
+						MetachainConsensusGroupSize: groupSize,
+					}, nil
+				},
+			},
+			core.MetachainShardId,
+			2,
+		)
+
+		require.Equal(t, int(groupSize), size)
+	})
+
+	t.Run("on fail, use current parameters", func(t *testing.T) {
+		t.Parallel()
+
+		groupSize := uint32(400)
+
+		size := common.ConsensusGroupSizeForShardAndEpoch(
+			&testscommon.LoggerStub{},
+			&chainParameters.ChainParametersHandlerStub{
+				ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) {
+					return config.ChainParametersByEpochConfig{}, errors.New("fail")
+				},
+				CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig {
+					return config.ChainParametersByEpochConfig{
+						MetachainConsensusGroupSize: groupSize,
+					}
+				},
+			},
+			core.MetachainShardId,
+			2,
+		)
+
+		require.Equal(t, int(groupSize), size)
+	})
 }
diff --git a/common/constants.go b/common/constants.go
index 9b75e52f7e9..2d2de8270da 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -94,6 +94,9 @@ const ConnectionTopic = "connection"
 // ValidatorInfoTopic is the topic used for validatorInfo signaling
 const ValidatorInfoTopic = "validatorInfo"
 
+// EquivalentProofsTopic is the topic used for equivalent proofs
+const EquivalentProofsTopic = "equivalentProofs"
+
 // MetricCurrentRound is the metric for monitoring the current round of a node
 const MetricCurrentRound = "erd_current_round"
 
@@ -731,8 +734,8 @@ const (
 	// MetricEGLDInMultiTransferEnableEpoch represents the epoch when EGLD in multi transfer feature is enabled
 	MetricEGLDInMultiTransferEnableEpoch = "erd_egld_in_multi_transfer_enable_epoch"
 
-	// MetricCryptoOpcodesV2EnableEpoch represents the epoch when crypto opcodes v2 feature is enabled
-	MetricCryptoOpcodesV2EnableEpoch = "erd_crypto_opcodes_v2_enable_epoch"
+	// 
MetricCheckBuiltInCallOnTransferValueAndFailEnableRound represents the round when check builtincall on transfer value and fail is enabled + MetricCheckBuiltInCallOnTransferValueAndFailEnableRound = "erd_checkbuiltincall_ontransfervalueandfail_enable_round" // MetricMultiESDTNFTTransferAndExecuteByUserEnableEpoch represents the epoch when enshrined sovereign opcodes are enabled MetricMultiESDTNFTTransferAndExecuteByUserEnableEpoch = "erd_multi_esdt_transfer_execute_by_user_enable_epoch" @@ -764,6 +767,9 @@ const ( // MetricMaxNodesChangeEnableEpoch holds configuration for changing the maximum number of nodes and the enabling epoch MetricMaxNodesChangeEnableEpoch = "erd_max_nodes_change_enable_epoch" + // MetricCryptoOpcodesV2EnableEpoch represents the epoch when crypto opcodes v2 feature is enabled + MetricCryptoOpcodesV2EnableEpoch = "erd_crypto_opcodes_v2_enable_epoch" + // MetricEpochEnable represents the epoch when the max nodes change configuration is applied MetricEpochEnable = "erd_epoch_enable" @@ -869,10 +875,14 @@ const ( const ( // StorerOrder defines the order of storers to be notified of a start of epoch event StorerOrder = iota + // ChainParametersOrder defines the order in which ChainParameters is notified of a start of epoch event + ChainParametersOrder // NodesCoordinatorOrder defines the order in which NodesCoordinator is notified of a start of epoch event NodesCoordinatorOrder - // ConsensusOrder defines the order in which Consensus is notified of a start of epoch event - ConsensusOrder + // ConsensusHandlerOrder defines the order in which ConsensusHandler is notified of a start of epoch event + ConsensusHandlerOrder + // ConsensusStartRoundOrder defines the order in which Consensus StartRound subround is notified of a start of epoch event + ConsensusStartRoundOrder // NetworkShardingOrder defines the order in which the network sharding subsystem is notified of a start of epoch event NetworkShardingOrder // IndexerOrder defines the order in which indexer is notified of a start of epoch event @@ -981,7 +991,7 @@ const PutInStorerMaxTime = time.Second const DefaultUnstakedEpoch = math.MaxUint32 // InvalidMessageBlacklistDuration represents the time to keep a peer in the black list if it sends a message that -// does not follow the protocol: example not useing the same marshaler as the other peers +// does not follow the protocol: example not using the same marshaler as the other peers const InvalidMessageBlacklistDuration = time.Second * 3600 // PublicKeyBlacklistDuration represents the time to keep a public key in the black list if it will degrade its @@ -1136,132 +1146,134 @@ const FullArchiveMetricSuffix = "_full_archive" // Enable epoch flags definitions const ( - SCDeployFlag core.EnableEpochFlag = "SCDeployFlag" - BuiltInFunctionsFlag core.EnableEpochFlag = "BuiltInFunctionsFlag" - RelayedTransactionsFlag core.EnableEpochFlag = "RelayedTransactionsFlag" - PenalizedTooMuchGasFlag core.EnableEpochFlag = "PenalizedTooMuchGasFlag" - SwitchJailWaitingFlag core.EnableEpochFlag = "SwitchJailWaitingFlag" - BelowSignedThresholdFlag core.EnableEpochFlag = "BelowSignedThresholdFlag" - SwitchHysteresisForMinNodesFlagInSpecificEpochOnly core.EnableEpochFlag = "SwitchHysteresisForMinNodesFlagInSpecificEpochOnly" - TransactionSignedWithTxHashFlag core.EnableEpochFlag = "TransactionSignedWithTxHashFlag" - MetaProtectionFlag core.EnableEpochFlag = "MetaProtectionFlag" - AheadOfTimeGasUsageFlag core.EnableEpochFlag = "AheadOfTimeGasUsageFlag" - GasPriceModifierFlag core.EnableEpochFlag 
= "GasPriceModifierFlag" - RepairCallbackFlag core.EnableEpochFlag = "RepairCallbackFlag" - ReturnDataToLastTransferFlagAfterEpoch core.EnableEpochFlag = "ReturnDataToLastTransferFlagAfterEpoch" - SenderInOutTransferFlag core.EnableEpochFlag = "SenderInOutTransferFlag" - StakeFlag core.EnableEpochFlag = "StakeFlag" - StakingV2Flag core.EnableEpochFlag = "StakingV2Flag" - StakingV2OwnerFlagInSpecificEpochOnly core.EnableEpochFlag = "StakingV2OwnerFlagInSpecificEpochOnly" - StakingV2FlagAfterEpoch core.EnableEpochFlag = "StakingV2FlagAfterEpoch" - DoubleKeyProtectionFlag core.EnableEpochFlag = "DoubleKeyProtectionFlag" - ESDTFlag core.EnableEpochFlag = "ESDTFlag" - ESDTFlagInSpecificEpochOnly core.EnableEpochFlag = "ESDTFlagInSpecificEpochOnly" - GovernanceFlag core.EnableEpochFlag = "GovernanceFlag" + SCDeployFlag core.EnableEpochFlag = "SCDeployFlag" + BuiltInFunctionsFlag core.EnableEpochFlag = "BuiltInFunctionsFlag" + RelayedTransactionsFlag core.EnableEpochFlag = "RelayedTransactionsFlag" + PenalizedTooMuchGasFlag core.EnableEpochFlag = "PenalizedTooMuchGasFlag" + SwitchJailWaitingFlag core.EnableEpochFlag = "SwitchJailWaitingFlag" + BelowSignedThresholdFlag core.EnableEpochFlag = "BelowSignedThresholdFlag" + SwitchHysteresisForMinNodesFlagInSpecificEpochOnly core.EnableEpochFlag = "SwitchHysteresisForMinNodesFlagInSpecificEpochOnly" + TransactionSignedWithTxHashFlag core.EnableEpochFlag = "TransactionSignedWithTxHashFlag" + MetaProtectionFlag core.EnableEpochFlag = "MetaProtectionFlag" + AheadOfTimeGasUsageFlag core.EnableEpochFlag = "AheadOfTimeGasUsageFlag" + GasPriceModifierFlag core.EnableEpochFlag = "GasPriceModifierFlag" + RepairCallbackFlag core.EnableEpochFlag = "RepairCallbackFlag" + ReturnDataToLastTransferFlagAfterEpoch core.EnableEpochFlag = "ReturnDataToLastTransferFlagAfterEpoch" + SenderInOutTransferFlag core.EnableEpochFlag = "SenderInOutTransferFlag" + StakeFlag core.EnableEpochFlag = "StakeFlag" + StakingV2Flag core.EnableEpochFlag = "StakingV2Flag" + StakingV2OwnerFlagInSpecificEpochOnly core.EnableEpochFlag = "StakingV2OwnerFlagInSpecificEpochOnly" + StakingV2FlagAfterEpoch core.EnableEpochFlag = "StakingV2FlagAfterEpoch" + DoubleKeyProtectionFlag core.EnableEpochFlag = "DoubleKeyProtectionFlag" + ESDTFlag core.EnableEpochFlag = "ESDTFlag" + ESDTFlagInSpecificEpochOnly core.EnableEpochFlag = "ESDTFlagInSpecificEpochOnly" + GovernanceFlag core.EnableEpochFlag = "GovernanceFlag" GovernanceDisableProposeFlag core.EnableEpochFlag = "GovernanceDisableProposeFlag" GovernanceFixesFlag core.EnableEpochFlag = "GovernanceFixesFlag" - GovernanceFlagInSpecificEpochOnly core.EnableEpochFlag = "GovernanceFlagInSpecificEpochOnly" - DelegationManagerFlag core.EnableEpochFlag = "DelegationManagerFlag" - DelegationSmartContractFlag core.EnableEpochFlag = "DelegationSmartContractFlag" - DelegationSmartContractFlagInSpecificEpochOnly core.EnableEpochFlag = "DelegationSmartContractFlagInSpecificEpochOnly" - CorrectLastUnJailedFlag core.EnableEpochFlag = "CorrectLastUnJailedFlag" - CorrectLastUnJailedFlagInSpecificEpochOnly core.EnableEpochFlag = "CorrectLastUnJailedFlagInSpecificEpochOnly" - RelayedTransactionsV2Flag core.EnableEpochFlag = "RelayedTransactionsV2Flag" - UnBondTokensV2Flag core.EnableEpochFlag = "UnBondTokensV2Flag" - SaveJailedAlwaysFlag core.EnableEpochFlag = "SaveJailedAlwaysFlag" - ReDelegateBelowMinCheckFlag core.EnableEpochFlag = "ReDelegateBelowMinCheckFlag" - ValidatorToDelegationFlag core.EnableEpochFlag = "ValidatorToDelegationFlag" - 
IncrementSCRNonceInMultiTransferFlag core.EnableEpochFlag = "IncrementSCRNonceInMultiTransferFlag" - ESDTMultiTransferFlag core.EnableEpochFlag = "ESDTMultiTransferFlag" - GlobalMintBurnFlag core.EnableEpochFlag = "GlobalMintBurnFlag" - ESDTTransferRoleFlag core.EnableEpochFlag = "ESDTTransferRoleFlag" - ComputeRewardCheckpointFlag core.EnableEpochFlag = "ComputeRewardCheckpointFlag" - SCRSizeInvariantCheckFlag core.EnableEpochFlag = "SCRSizeInvariantCheckFlag" - BackwardCompSaveKeyValueFlag core.EnableEpochFlag = "BackwardCompSaveKeyValueFlag" - ESDTNFTCreateOnMultiShardFlag core.EnableEpochFlag = "ESDTNFTCreateOnMultiShardFlag" - MetaESDTSetFlag core.EnableEpochFlag = "MetaESDTSetFlag" - AddTokensToDelegationFlag core.EnableEpochFlag = "AddTokensToDelegationFlag" - MultiESDTTransferFixOnCallBackFlag core.EnableEpochFlag = "MultiESDTTransferFixOnCallBackFlag" - OptimizeGasUsedInCrossMiniBlocksFlag core.EnableEpochFlag = "OptimizeGasUsedInCrossMiniBlocksFlag" - CorrectFirstQueuedFlag core.EnableEpochFlag = "CorrectFirstQueuedFlag" - DeleteDelegatorAfterClaimRewardsFlag core.EnableEpochFlag = "DeleteDelegatorAfterClaimRewardsFlag" - RemoveNonUpdatedStorageFlag core.EnableEpochFlag = "RemoveNonUpdatedStorageFlag" - OptimizeNFTStoreFlag core.EnableEpochFlag = "OptimizeNFTStoreFlag" - CreateNFTThroughExecByCallerFlag core.EnableEpochFlag = "CreateNFTThroughExecByCallerFlag" - StopDecreasingValidatorRatingWhenStuckFlag core.EnableEpochFlag = "StopDecreasingValidatorRatingWhenStuckFlag" - FrontRunningProtectionFlag core.EnableEpochFlag = "FrontRunningProtectionFlag" - PayableBySCFlag core.EnableEpochFlag = "PayableBySCFlag" - CleanUpInformativeSCRsFlag core.EnableEpochFlag = "CleanUpInformativeSCRsFlag" - StorageAPICostOptimizationFlag core.EnableEpochFlag = "StorageAPICostOptimizationFlag" - ESDTRegisterAndSetAllRolesFlag core.EnableEpochFlag = "ESDTRegisterAndSetAllRolesFlag" - ScheduledMiniBlocksFlag core.EnableEpochFlag = "ScheduledMiniBlocksFlag" - CorrectJailedNotUnStakedEmptyQueueFlag core.EnableEpochFlag = "CorrectJailedNotUnStakedEmptyQueueFlag" - DoNotReturnOldBlockInBlockchainHookFlag core.EnableEpochFlag = "DoNotReturnOldBlockInBlockchainHookFlag" - AddFailedRelayedTxToInvalidMBsFlag core.EnableEpochFlag = "AddFailedRelayedTxToInvalidMBsFlag" - SCRSizeInvariantOnBuiltInResultFlag core.EnableEpochFlag = "SCRSizeInvariantOnBuiltInResultFlag" - CheckCorrectTokenIDForTransferRoleFlag core.EnableEpochFlag = "CheckCorrectTokenIDForTransferRoleFlag" - FailExecutionOnEveryAPIErrorFlag core.EnableEpochFlag = "FailExecutionOnEveryAPIErrorFlag" - MiniBlockPartialExecutionFlag core.EnableEpochFlag = "MiniBlockPartialExecutionFlag" - ManagedCryptoAPIsFlag core.EnableEpochFlag = "ManagedCryptoAPIsFlag" - ESDTMetadataContinuousCleanupFlag core.EnableEpochFlag = "ESDTMetadataContinuousCleanupFlag" - DisableExecByCallerFlag core.EnableEpochFlag = "DisableExecByCallerFlag" - RefactorContextFlag core.EnableEpochFlag = "RefactorContextFlag" - CheckFunctionArgumentFlag core.EnableEpochFlag = "CheckFunctionArgumentFlag" - CheckExecuteOnReadOnlyFlag core.EnableEpochFlag = "CheckExecuteOnReadOnlyFlag" - SetSenderInEeiOutputTransferFlag core.EnableEpochFlag = "SetSenderInEeiOutputTransferFlag" - FixAsyncCallbackCheckFlag core.EnableEpochFlag = "FixAsyncCallbackCheckFlag" - SaveToSystemAccountFlag core.EnableEpochFlag = "SaveToSystemAccountFlag" - CheckFrozenCollectionFlag core.EnableEpochFlag = "CheckFrozenCollectionFlag" - SendAlwaysFlag core.EnableEpochFlag = "SendAlwaysFlag" - ValueLengthCheckFlag 
core.EnableEpochFlag = "ValueLengthCheckFlag" - CheckTransferFlag core.EnableEpochFlag = "CheckTransferFlag" - ESDTNFTImprovementV1Flag core.EnableEpochFlag = "ESDTNFTImprovementV1Flag" - ChangeDelegationOwnerFlag core.EnableEpochFlag = "ChangeDelegationOwnerFlag" - RefactorPeersMiniBlocksFlag core.EnableEpochFlag = "RefactorPeersMiniBlocksFlag" - SCProcessorV2Flag core.EnableEpochFlag = "SCProcessorV2Flag" - FixAsyncCallBackArgsListFlag core.EnableEpochFlag = "FixAsyncCallBackArgsListFlag" - FixOldTokenLiquidityFlag core.EnableEpochFlag = "FixOldTokenLiquidityFlag" - RuntimeMemStoreLimitFlag core.EnableEpochFlag = "RuntimeMemStoreLimitFlag" - RuntimeCodeSizeFixFlag core.EnableEpochFlag = "RuntimeCodeSizeFixFlag" - MaxBlockchainHookCountersFlag core.EnableEpochFlag = "MaxBlockchainHookCountersFlag" - WipeSingleNFTLiquidityDecreaseFlag core.EnableEpochFlag = "WipeSingleNFTLiquidityDecreaseFlag" - AlwaysSaveTokenMetaDataFlag core.EnableEpochFlag = "AlwaysSaveTokenMetaDataFlag" - SetGuardianFlag core.EnableEpochFlag = "SetGuardianFlag" - RelayedNonceFixFlag core.EnableEpochFlag = "RelayedNonceFixFlag" - ConsistentTokensValuesLengthCheckFlag core.EnableEpochFlag = "ConsistentTokensValuesLengthCheckFlag" - KeepExecOrderOnCreatedSCRsFlag core.EnableEpochFlag = "KeepExecOrderOnCreatedSCRsFlag" - MultiClaimOnDelegationFlag core.EnableEpochFlag = "MultiClaimOnDelegationFlag" - ChangeUsernameFlag core.EnableEpochFlag = "ChangeUsernameFlag" - AutoBalanceDataTriesFlag core.EnableEpochFlag = "AutoBalanceDataTriesFlag" - MigrateDataTrieFlag core.EnableEpochFlag = "MigrateDataTrieFlag" - FixDelegationChangeOwnerOnAccountFlag core.EnableEpochFlag = "FixDelegationChangeOwnerOnAccountFlag" - FixOOGReturnCodeFlag core.EnableEpochFlag = "FixOOGReturnCodeFlag" - DeterministicSortOnValidatorsInfoFixFlag core.EnableEpochFlag = "DeterministicSortOnValidatorsInfoFixFlag" - DynamicGasCostForDataTrieStorageLoadFlag core.EnableEpochFlag = "DynamicGasCostForDataTrieStorageLoadFlag" - ScToScLogEventFlag core.EnableEpochFlag = "ScToScLogEventFlag" - BlockGasAndFeesReCheckFlag core.EnableEpochFlag = "BlockGasAndFeesReCheckFlag" - BalanceWaitingListsFlag core.EnableEpochFlag = "BalanceWaitingListsFlag" - NFTStopCreateFlag core.EnableEpochFlag = "NFTStopCreateFlag" - FixGasRemainingForSaveKeyValueFlag core.EnableEpochFlag = "FixGasRemainingForSaveKeyValueFlag" - IsChangeOwnerAddressCrossShardThroughSCFlag core.EnableEpochFlag = "IsChangeOwnerAddressCrossShardThroughSCFlag" - CurrentRandomnessOnSortingFlag core.EnableEpochFlag = "CurrentRandomnessOnSortingFlag" - StakeLimitsFlag core.EnableEpochFlag = "StakeLimitsFlag" - StakingV4Step1Flag core.EnableEpochFlag = "StakingV4Step1Flag" - StakingV4Step2Flag core.EnableEpochFlag = "StakingV4Step2Flag" - StakingV4Step3Flag core.EnableEpochFlag = "StakingV4Step3Flag" - CleanupAuctionOnLowWaitingListFlag core.EnableEpochFlag = "CleanupAuctionOnLowWaitingListFlag" - StakingV4StartedFlag core.EnableEpochFlag = "StakingV4StartedFlag" - AlwaysMergeContextsInEEIFlag core.EnableEpochFlag = "AlwaysMergeContextsInEEIFlag" - UseGasBoundedShouldFailExecutionFlag core.EnableEpochFlag = "UseGasBoundedShouldFailExecutionFlag" - DynamicESDTFlag core.EnableEpochFlag = "DynamicEsdtFlag" - EGLDInESDTMultiTransferFlag core.EnableEpochFlag = "EGLDInESDTMultiTransferFlag" - CryptoOpcodesV2Flag core.EnableEpochFlag = "CryptoOpcodesV2Flag" - UnJailCleanupFlag core.EnableEpochFlag = "UnJailCleanupFlag" - FixRelayedBaseCostFlag core.EnableEpochFlag = "FixRelayedBaseCostFlag" - 
MultiESDTNFTTransferAndExecuteByUserFlag core.EnableEpochFlag = "MultiESDTNFTTransferAndExecuteByUserFlag" - FixRelayedMoveBalanceToNonPayableSCFlag core.EnableEpochFlag = "FixRelayedMoveBalanceToNonPayableSCFlag" - RelayedTransactionsV3Flag core.EnableEpochFlag = "RelayedTransactionsV3Flag" - RelayedTransactionsV3FixESDTTransferFlag core.EnableEpochFlag = "RelayedTransactionsV3FixESDTTransferFlag" + GovernanceFlagInSpecificEpochOnly core.EnableEpochFlag = "GovernanceFlagInSpecificEpochOnly" + DelegationManagerFlag core.EnableEpochFlag = "DelegationManagerFlag" + DelegationSmartContractFlag core.EnableEpochFlag = "DelegationSmartContractFlag" + DelegationSmartContractFlagInSpecificEpochOnly core.EnableEpochFlag = "DelegationSmartContractFlagInSpecificEpochOnly" + CorrectLastUnJailedFlag core.EnableEpochFlag = "CorrectLastUnJailedFlag" + CorrectLastUnJailedFlagInSpecificEpochOnly core.EnableEpochFlag = "CorrectLastUnJailedFlagInSpecificEpochOnly" + RelayedTransactionsV2Flag core.EnableEpochFlag = "RelayedTransactionsV2Flag" + UnBondTokensV2Flag core.EnableEpochFlag = "UnBondTokensV2Flag" + SaveJailedAlwaysFlag core.EnableEpochFlag = "SaveJailedAlwaysFlag" + ReDelegateBelowMinCheckFlag core.EnableEpochFlag = "ReDelegateBelowMinCheckFlag" + ValidatorToDelegationFlag core.EnableEpochFlag = "ValidatorToDelegationFlag" + IncrementSCRNonceInMultiTransferFlag core.EnableEpochFlag = "IncrementSCRNonceInMultiTransferFlag" + ESDTMultiTransferFlag core.EnableEpochFlag = "ESDTMultiTransferFlag" + GlobalMintBurnFlag core.EnableEpochFlag = "GlobalMintBurnFlag" + ESDTTransferRoleFlag core.EnableEpochFlag = "ESDTTransferRoleFlag" + ComputeRewardCheckpointFlag core.EnableEpochFlag = "ComputeRewardCheckpointFlag" + SCRSizeInvariantCheckFlag core.EnableEpochFlag = "SCRSizeInvariantCheckFlag" + BackwardCompSaveKeyValueFlag core.EnableEpochFlag = "BackwardCompSaveKeyValueFlag" + ESDTNFTCreateOnMultiShardFlag core.EnableEpochFlag = "ESDTNFTCreateOnMultiShardFlag" + MetaESDTSetFlag core.EnableEpochFlag = "MetaESDTSetFlag" + AddTokensToDelegationFlag core.EnableEpochFlag = "AddTokensToDelegationFlag" + MultiESDTTransferFixOnCallBackFlag core.EnableEpochFlag = "MultiESDTTransferFixOnCallBackFlag" + OptimizeGasUsedInCrossMiniBlocksFlag core.EnableEpochFlag = "OptimizeGasUsedInCrossMiniBlocksFlag" + CorrectFirstQueuedFlag core.EnableEpochFlag = "CorrectFirstQueuedFlag" + DeleteDelegatorAfterClaimRewardsFlag core.EnableEpochFlag = "DeleteDelegatorAfterClaimRewardsFlag" + RemoveNonUpdatedStorageFlag core.EnableEpochFlag = "RemoveNonUpdatedStorageFlag" + OptimizeNFTStoreFlag core.EnableEpochFlag = "OptimizeNFTStoreFlag" + CreateNFTThroughExecByCallerFlag core.EnableEpochFlag = "CreateNFTThroughExecByCallerFlag" + StopDecreasingValidatorRatingWhenStuckFlag core.EnableEpochFlag = "StopDecreasingValidatorRatingWhenStuckFlag" + FrontRunningProtectionFlag core.EnableEpochFlag = "FrontRunningProtectionFlag" + PayableBySCFlag core.EnableEpochFlag = "PayableBySCFlag" + CleanUpInformativeSCRsFlag core.EnableEpochFlag = "CleanUpInformativeSCRsFlag" + StorageAPICostOptimizationFlag core.EnableEpochFlag = "StorageAPICostOptimizationFlag" + ESDTRegisterAndSetAllRolesFlag core.EnableEpochFlag = "ESDTRegisterAndSetAllRolesFlag" + ScheduledMiniBlocksFlag core.EnableEpochFlag = "ScheduledMiniBlocksFlag" + CorrectJailedNotUnStakedEmptyQueueFlag core.EnableEpochFlag = "CorrectJailedNotUnStakedEmptyQueueFlag" + DoNotReturnOldBlockInBlockchainHookFlag core.EnableEpochFlag = "DoNotReturnOldBlockInBlockchainHookFlag" + 
AddFailedRelayedTxToInvalidMBsFlag core.EnableEpochFlag = "AddFailedRelayedTxToInvalidMBsFlag" + SCRSizeInvariantOnBuiltInResultFlag core.EnableEpochFlag = "SCRSizeInvariantOnBuiltInResultFlag" + CheckCorrectTokenIDForTransferRoleFlag core.EnableEpochFlag = "CheckCorrectTokenIDForTransferRoleFlag" + FailExecutionOnEveryAPIErrorFlag core.EnableEpochFlag = "FailExecutionOnEveryAPIErrorFlag" + MiniBlockPartialExecutionFlag core.EnableEpochFlag = "MiniBlockPartialExecutionFlag" + ManagedCryptoAPIsFlag core.EnableEpochFlag = "ManagedCryptoAPIsFlag" + ESDTMetadataContinuousCleanupFlag core.EnableEpochFlag = "ESDTMetadataContinuousCleanupFlag" + DisableExecByCallerFlag core.EnableEpochFlag = "DisableExecByCallerFlag" + RefactorContextFlag core.EnableEpochFlag = "RefactorContextFlag" + CheckFunctionArgumentFlag core.EnableEpochFlag = "CheckFunctionArgumentFlag" + CheckExecuteOnReadOnlyFlag core.EnableEpochFlag = "CheckExecuteOnReadOnlyFlag" + SetSenderInEeiOutputTransferFlag core.EnableEpochFlag = "SetSenderInEeiOutputTransferFlag" + FixAsyncCallbackCheckFlag core.EnableEpochFlag = "FixAsyncCallbackCheckFlag" + SaveToSystemAccountFlag core.EnableEpochFlag = "SaveToSystemAccountFlag" + CheckFrozenCollectionFlag core.EnableEpochFlag = "CheckFrozenCollectionFlag" + SendAlwaysFlag core.EnableEpochFlag = "SendAlwaysFlag" + ValueLengthCheckFlag core.EnableEpochFlag = "ValueLengthCheckFlag" + CheckTransferFlag core.EnableEpochFlag = "CheckTransferFlag" + ESDTNFTImprovementV1Flag core.EnableEpochFlag = "ESDTNFTImprovementV1Flag" + ChangeDelegationOwnerFlag core.EnableEpochFlag = "ChangeDelegationOwnerFlag" + RefactorPeersMiniBlocksFlag core.EnableEpochFlag = "RefactorPeersMiniBlocksFlag" + SCProcessorV2Flag core.EnableEpochFlag = "SCProcessorV2Flag" + FixAsyncCallBackArgsListFlag core.EnableEpochFlag = "FixAsyncCallBackArgsListFlag" + FixOldTokenLiquidityFlag core.EnableEpochFlag = "FixOldTokenLiquidityFlag" + RuntimeMemStoreLimitFlag core.EnableEpochFlag = "RuntimeMemStoreLimitFlag" + RuntimeCodeSizeFixFlag core.EnableEpochFlag = "RuntimeCodeSizeFixFlag" + MaxBlockchainHookCountersFlag core.EnableEpochFlag = "MaxBlockchainHookCountersFlag" + WipeSingleNFTLiquidityDecreaseFlag core.EnableEpochFlag = "WipeSingleNFTLiquidityDecreaseFlag" + AlwaysSaveTokenMetaDataFlag core.EnableEpochFlag = "AlwaysSaveTokenMetaDataFlag" + SetGuardianFlag core.EnableEpochFlag = "SetGuardianFlag" + RelayedNonceFixFlag core.EnableEpochFlag = "RelayedNonceFixFlag" + ConsistentTokensValuesLengthCheckFlag core.EnableEpochFlag = "ConsistentTokensValuesLengthCheckFlag" + KeepExecOrderOnCreatedSCRsFlag core.EnableEpochFlag = "KeepExecOrderOnCreatedSCRsFlag" + MultiClaimOnDelegationFlag core.EnableEpochFlag = "MultiClaimOnDelegationFlag" + ChangeUsernameFlag core.EnableEpochFlag = "ChangeUsernameFlag" + AutoBalanceDataTriesFlag core.EnableEpochFlag = "AutoBalanceDataTriesFlag" + MigrateDataTrieFlag core.EnableEpochFlag = "MigrateDataTrieFlag" + FixDelegationChangeOwnerOnAccountFlag core.EnableEpochFlag = "FixDelegationChangeOwnerOnAccountFlag" + FixOOGReturnCodeFlag core.EnableEpochFlag = "FixOOGReturnCodeFlag" + DeterministicSortOnValidatorsInfoFixFlag core.EnableEpochFlag = "DeterministicSortOnValidatorsInfoFixFlag" + DynamicGasCostForDataTrieStorageLoadFlag core.EnableEpochFlag = "DynamicGasCostForDataTrieStorageLoadFlag" + ScToScLogEventFlag core.EnableEpochFlag = "ScToScLogEventFlag" + BlockGasAndFeesReCheckFlag core.EnableEpochFlag = "BlockGasAndFeesReCheckFlag" + BalanceWaitingListsFlag core.EnableEpochFlag = 
"BalanceWaitingListsFlag" + NFTStopCreateFlag core.EnableEpochFlag = "NFTStopCreateFlag" + FixGasRemainingForSaveKeyValueFlag core.EnableEpochFlag = "FixGasRemainingForSaveKeyValueFlag" + IsChangeOwnerAddressCrossShardThroughSCFlag core.EnableEpochFlag = "IsChangeOwnerAddressCrossShardThroughSCFlag" + CurrentRandomnessOnSortingFlag core.EnableEpochFlag = "CurrentRandomnessOnSortingFlag" + StakeLimitsFlag core.EnableEpochFlag = "StakeLimitsFlag" + StakingV4Step1Flag core.EnableEpochFlag = "StakingV4Step1Flag" + StakingV4Step2Flag core.EnableEpochFlag = "StakingV4Step2Flag" + StakingV4Step3Flag core.EnableEpochFlag = "StakingV4Step3Flag" + CleanupAuctionOnLowWaitingListFlag core.EnableEpochFlag = "CleanupAuctionOnLowWaitingListFlag" + StakingV4StartedFlag core.EnableEpochFlag = "StakingV4StartedFlag" + AlwaysMergeContextsInEEIFlag core.EnableEpochFlag = "AlwaysMergeContextsInEEIFlag" + UseGasBoundedShouldFailExecutionFlag core.EnableEpochFlag = "UseGasBoundedShouldFailExecutionFlag" + DynamicESDTFlag core.EnableEpochFlag = "DynamicEsdtFlag" + EGLDInESDTMultiTransferFlag core.EnableEpochFlag = "EGLDInESDTMultiTransferFlag" + CryptoOpcodesV2Flag core.EnableEpochFlag = "CryptoOpcodesV2Flag" + UnJailCleanupFlag core.EnableEpochFlag = "UnJailCleanupFlag" + FixRelayedBaseCostFlag core.EnableEpochFlag = "FixRelayedBaseCostFlag" + MultiESDTNFTTransferAndExecuteByUserFlag core.EnableEpochFlag = "MultiESDTNFTTransferAndExecuteByUserFlag" + FixRelayedMoveBalanceToNonPayableSCFlag core.EnableEpochFlag = "FixRelayedMoveBalanceToNonPayableSCFlag" + RelayedTransactionsV3Flag core.EnableEpochFlag = "RelayedTransactionsV3Flag" + RelayedTransactionsV3FixESDTTransferFlag core.EnableEpochFlag = "RelayedTransactionsV3FixESDTTransferFlag" + AndromedaFlag core.EnableEpochFlag = "AndromedaFlag" + CheckBuiltInCallOnTransferValueAndFailExecutionFlag core.EnableEpochFlag = "CheckBuiltInCallOnTransferValueAndFailExecutionFlag" MaskInternalDependenciesErrorsFlag core.EnableEpochFlag = "MaskInternalDependenciesErrorsFlag" FixBackTransferOPCODEFlag core.EnableEpochFlag = "FixBackTransferOPCODEFlag" ValidationOnGobDecodeFlag core.EnableEpochFlag = "ValidationOnGobDecodeFlag" diff --git a/common/converters.go b/common/converters.go index 036cce7d070..83ccbc084e1 100644 --- a/common/converters.go +++ b/common/converters.go @@ -23,7 +23,7 @@ func ProcessDestinationShardAsObserver(destinationShardIdAsObserver string) (uin val, err := strconv.ParseUint(destShard, 10, 32) if err != nil { - return 0, fmt.Errorf("error parsing DestinationShardAsObserver option: " + err.Error()) + return 0, fmt.Errorf("error parsing DestinationShardAsObserver option: %s", err.Error()) } return uint32(val), err diff --git a/common/enablers/enableEpochsHandler.go b/common/enablers/enableEpochsHandler.go index b4ea988789d..bd08d591a75 100644 --- a/common/enablers/enableEpochsHandler.go +++ b/common/enablers/enableEpochsHandler.go @@ -804,6 +804,19 @@ func (handler *enableEpochsHandler) createAllFlagsMap() { }, activationEpoch: handler.enableEpochsConfig.RelayedTransactionsV3FixESDTTransferEnableEpoch, }, + common.AndromedaFlag: { + isActiveInEpoch: func(epoch uint32) bool { + return epoch >= handler.enableEpochsConfig.AndromedaEnableEpoch + }, + activationEpoch: handler.enableEpochsConfig.AndromedaEnableEpoch, + }, + // TODO: move it to activation round + common.CheckBuiltInCallOnTransferValueAndFailExecutionFlag: { + isActiveInEpoch: func(epoch uint32) bool { + return epoch >= 
handler.enableEpochsConfig.CheckBuiltInCallOnTransferValueAndFailEnableRound + }, + activationEpoch: handler.enableEpochsConfig.CheckBuiltInCallOnTransferValueAndFailEnableRound, + }, common.AutomaticActivationOfNodesDisableFlag: { isActiveInEpoch: func(epoch uint32) bool { return epoch >= handler.enableEpochsConfig.AutomaticActivationOfNodesDisableEpoch diff --git a/common/enablers/enableEpochsHandler_test.go b/common/enablers/enableEpochsHandler_test.go index 87f8af30519..9dbbdf1fcbb 100644 --- a/common/enablers/enableEpochsHandler_test.go +++ b/common/enablers/enableEpochsHandler_test.go @@ -127,11 +127,13 @@ func createEnableEpochsConfig() config.EnableEpochs { UseGasBoundedShouldFailExecutionEnableEpoch: 108, RelayedTransactionsV3EnableEpoch: 109, RelayedTransactionsV3FixESDTTransferEnableEpoch: 110, - MaskVMInternalDependenciesErrorsEnableEpoch: 111, - FixBackTransferOPCODEEnableEpoch: 112, - ValidationOnGobDecodeEnableEpoch: 113, - BarnardOpcodesEnableEpoch: 114, - AutomaticActivationOfNodesDisableEpoch: 110, + AndromedaEnableEpoch: 111, + CheckBuiltInCallOnTransferValueAndFailEnableRound: 112, + MaskVMInternalDependenciesErrorsEnableEpoch: 113, + FixBackTransferOPCODEEnableEpoch: 114, + ValidationOnGobDecodeEnableEpoch: 115, + BarnardOpcodesEnableEpoch: 116, + AutomaticActivationOfNodesDisableEpoch: 117, } } @@ -336,6 +338,7 @@ func TestEnableEpochsHandler_IsFlagEnabled(t *testing.T) { require.True(t, handler.IsFlagEnabled(common.DynamicESDTFlag)) require.True(t, handler.IsFlagEnabled(common.FixRelayedBaseCostFlag)) require.True(t, handler.IsFlagEnabled(common.FixRelayedMoveBalanceToNonPayableSCFlag)) + require.True(t, handler.IsFlagEnabled(common.AndromedaFlag)) require.True(t, handler.IsFlagEnabled(common.DynamicESDTFlag)) } @@ -464,6 +467,8 @@ func TestEnableEpochsHandler_GetActivationEpoch(t *testing.T) { require.Equal(t, cfg.FixRelayedMoveBalanceToNonPayableSCEnableEpoch, handler.GetActivationEpoch(common.FixRelayedMoveBalanceToNonPayableSCFlag)) require.Equal(t, cfg.RelayedTransactionsV3EnableEpoch, handler.GetActivationEpoch(common.RelayedTransactionsV3Flag)) require.Equal(t, cfg.RelayedTransactionsV3FixESDTTransferEnableEpoch, handler.GetActivationEpoch(common.RelayedTransactionsV3FixESDTTransferFlag)) + require.Equal(t, cfg.AndromedaEnableEpoch, handler.GetActivationEpoch(common.AndromedaFlag)) + require.Equal(t, cfg.CheckBuiltInCallOnTransferValueAndFailEnableRound, handler.GetActivationEpoch(common.CheckBuiltInCallOnTransferValueAndFailExecutionFlag)) require.Equal(t, cfg.MaskVMInternalDependenciesErrorsEnableEpoch, handler.GetActivationEpoch(common.MaskInternalDependenciesErrorsFlag)) require.Equal(t, cfg.FixBackTransferOPCODEEnableEpoch, handler.GetActivationEpoch(common.FixBackTransferOPCODEFlag)) require.Equal(t, cfg.ValidationOnGobDecodeEnableEpoch, handler.GetActivationEpoch(common.ValidationOnGobDecodeFlag)) diff --git a/common/errors.go b/common/errors.go index 47b976de9a8..3be75377813 100644 --- a/common/errors.go +++ b/common/errors.go @@ -10,3 +10,27 @@ var ErrNilWasmChangeLocker = errors.New("nil wasm change locker") // ErrNilStateSyncNotifierSubscriber signals that a nil state sync notifier subscriber has been provided var ErrNilStateSyncNotifierSubscriber = errors.New("nil state sync notifier subscriber") + +// ErrInvalidHeaderProof signals that an invalid equivalent proof has been provided +var ErrInvalidHeaderProof = errors.New("invalid equivalent proof") + +// ErrNilHeaderProof signals that a nil equivalent proof has been provided +var ErrNilHeaderProof 
= errors.New("nil equivalent proof") + +// ErrAlreadyExistingEquivalentProof signals that the provided proof was already exiting in the pool +var ErrAlreadyExistingEquivalentProof = errors.New("already existing equivalent proof") + +// ErrNilHeaderHandler signals that a nil header handler has been provided +var ErrNilHeaderHandler = errors.New("nil header handler") + +// ErrNotEnoughSignatures defines the error for not enough signatures +var ErrNotEnoughSignatures = errors.New("not enough signatures") + +// ErrWrongSizeBitmap signals that the provided bitmap's length is bigger than the one that was required +var ErrWrongSizeBitmap = errors.New("wrong size bitmap has been provided") + +// ErrInvalidHashShardKey signals that the provided hash-shard key is invalid +var ErrInvalidHashShardKey = errors.New("invalid hash shard key") + +// ErrInvalidNonceShardKey signals that the provided nonce-shard key is invalid +var ErrInvalidNonceShardKey = errors.New("invalid nonce shard key") diff --git a/common/fieldsChecker/fieldsSizeChecker.go b/common/fieldsChecker/fieldsSizeChecker.go new file mode 100644 index 00000000000..ae91dd5e10e --- /dev/null +++ b/common/fieldsChecker/fieldsSizeChecker.go @@ -0,0 +1,75 @@ +package fieldsChecker + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/sharding" + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("fieldsChecker") + +const ( + // max size for signature in bytes + sigMaxSize = 100 +) + +type fieldsSizeChecker struct { + hasher hashing.Hasher + chainParametersHandler sharding.ChainParametersHandler +} + +// NewFieldsSizeChecker will create a new fields size checker component +func NewFieldsSizeChecker( + chainParametersHandler sharding.ChainParametersHandler, + hasher hashing.Hasher, +) (*fieldsSizeChecker, error) { + if check.IfNil(chainParametersHandler) { + return nil, errors.ErrNilChainParametersHandler + } + if check.IfNil(hasher) { + return nil, core.ErrNilHasher + } + + return &fieldsSizeChecker{ + chainParametersHandler: chainParametersHandler, + hasher: hasher, + }, nil +} + +// IsProofSizeValid will check proof fields size +func (pc *fieldsSizeChecker) IsProofSizeValid(proof data.HeaderProofHandler) bool { + epochForConsensus := common.GetEpochForConsensus(proof) + + return pc.isAggregatedSigSizeValid(proof.GetAggregatedSignature()) && + pc.isBitmapSizeValid(proof.GetPubKeysBitmap(), epochForConsensus, proof.GetHeaderShardId()) && + pc.isHeaderHashSizeValid(proof.GetHeaderHash()) +} + +func (pc *fieldsSizeChecker) isBitmapSizeValid( + bitmap []byte, + epoch uint32, + shardID uint32, +) bool { + consensusSize := common.ConsensusGroupSizeForShardAndEpoch(log, pc.chainParametersHandler, shardID, epoch) + expectedBitmapSize := common.GetBitmapSize(consensusSize) + + return len(bitmap) == expectedBitmapSize +} + +func (pc *fieldsSizeChecker) isHeaderHashSizeValid(headerHash []byte) bool { + return len(headerHash) == pc.hasher.Size() +} + +func (pc *fieldsSizeChecker) isAggregatedSigSizeValid(aggSig []byte) bool { + return len(aggSig) > 0 && len(aggSig) <= sigMaxSize +} + +// IsInterfaceNil - +func 
(pc *fieldsSizeChecker) IsInterfaceNil() bool { + return pc == nil +} diff --git a/common/fieldsChecker/fieldsSizeChecker_test.go b/common/fieldsChecker/fieldsSizeChecker_test.go new file mode 100644 index 00000000000..be7c9c6baae --- /dev/null +++ b/common/fieldsChecker/fieldsSizeChecker_test.go @@ -0,0 +1,113 @@ +package fieldsChecker_test + +import ( + "bytes" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common/fieldsChecker" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + "github.com/stretchr/testify/require" +) + +func TestNewFieldsSizeChecker(t *testing.T) { + t.Parallel() + + t.Run("nil chain parameters handler", func(t *testing.T) { + t.Parallel() + + fsc, err := fieldsChecker.NewFieldsSizeChecker( + nil, + &testscommon.HasherStub{}, + ) + require.Equal(t, errors.ErrNilChainParametersHandler, err) + require.Nil(t, fsc) + }) + + t.Run("nil hasher", func(t *testing.T) { + t.Parallel() + + fsc, err := fieldsChecker.NewFieldsSizeChecker( + &chainParameters.ChainParametersHandlerStub{}, + nil, + ) + require.Equal(t, core.ErrNilHasher, err) + require.Nil(t, fsc) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + fsc, err := fieldsChecker.NewFieldsSizeChecker( + &chainParameters.ChainParametersHandlerStub{}, + &testscommon.HasherStub{}, + ) + require.Nil(t, err) + require.NotNil(t, fsc) + require.False(t, fsc.IsInterfaceNil()) + }) +} + +func TestFieldsSizeChecker_IsProofSizeValid(t *testing.T) { + t.Parallel() + + fsc, err := fieldsChecker.NewFieldsSizeChecker( + &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 8, + }, nil + }, + }, + &testscommon.HasherStub{ + SizeCalled: func() int { + return 32 + }, + }, + ) + require.Nil(t, err) + + ok := fsc.IsProofSizeValid(&block.HeaderProof{ + PubKeysBitmap: []byte{1}, + AggregatedSignature: []byte("aggSig"), + HeaderHash: bytes.Repeat([]byte("h"), 32), + HeaderShardId: 1, + }) + require.True(t, ok) + + ok = fsc.IsProofSizeValid(&block.HeaderProof{ + PubKeysBitmap: []byte{1, 2}, // bigger bitmap + AggregatedSignature: []byte("aggSig"), + HeaderHash: bytes.Repeat([]byte("h"), 32), + HeaderShardId: 1, + }) + require.False(t, ok) + + ok = fsc.IsProofSizeValid(&block.HeaderProof{ + PubKeysBitmap: []byte{1}, + AggregatedSignature: []byte{}, // empty agg sig + HeaderHash: bytes.Repeat([]byte("h"), 32), + HeaderShardId: 1, + }) + require.False(t, ok) + + ok = fsc.IsProofSizeValid(&block.HeaderProof{ + PubKeysBitmap: []byte{1}, + AggregatedSignature: []byte("aggSig"), + HeaderHash: bytes.Repeat([]byte("h"), 33), // bigger hash size + HeaderShardId: 1, + }) + require.False(t, ok) + + ok = fsc.IsProofSizeValid(&block.HeaderProof{ + PubKeysBitmap: []byte{1}, + AggregatedSignature: bytes.Repeat([]byte("h"), 101), // bigger sig size + HeaderHash: bytes.Repeat([]byte("h"), 32), + HeaderShardId: 1, + }) + require.False(t, ok) +} diff --git a/common/forking/genericEpochNotifier_test.go b/common/forking/genericEpochNotifier_test.go index ca78700d2a0..a0a649c098c 100644 --- 
a/common/forking/genericEpochNotifier_test.go +++ b/common/forking/genericEpochNotifier_test.go @@ -1,11 +1,13 @@ package forking import ( + "sync" "sync/atomic" "testing" "time" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-go/common/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/stretchr/testify/assert" @@ -152,3 +154,33 @@ func TestGenericEpochNotifier_CheckEpochInSyncShouldWork(t *testing.T) { assert.Equal(t, uint32(2), atomic.LoadUint32(&numCalls)) assert.True(t, end.Sub(start) >= handlerWait) } + +func TestGenericEpochNotifier_ConcurrentOperations(t *testing.T) { + t.Parallel() + + notifier := NewGenericEpochNotifier() + + numOperations := 500 + wg := sync.WaitGroup{} + wg.Add(numOperations) + for i := 0; i < numOperations; i++ { + go func(idx int) { + switch idx { + case 0: + notifier.RegisterNotifyHandler(&mock.EpochSubscriberHandlerStub{}) + case 1: + _ = notifier.CurrentEpoch() + case 2: + notifier.CheckEpoch(&block.MetaBlock{Epoch: 5}) + case 3: + notifier.UnRegisterAll() + case 4: + _ = notifier.IsInterfaceNil() + } + + wg.Done() + }(i % 5) + } + + wg.Wait() +} diff --git a/common/graceperiod/epochChange.go b/common/graceperiod/epochChange.go new file mode 100644 index 00000000000..f67bb54a1ea --- /dev/null +++ b/common/graceperiod/epochChange.go @@ -0,0 +1,69 @@ +package graceperiod + +import ( + "errors" + "sort" + + "github.com/multiversx/mx-chain-go/config" +) + +var errEmptyGracePeriodByEpochConfig = errors.New("empty grace period by epoch config") +var errDuplicatedEpochConfig = errors.New("duplicated epoch config") +var errMissingEpochZeroConfig = errors.New("missing configuration for epoch 0") + +// epochChangeGracePeriod holds the grace period configuration for epoch changes +type epochChangeGracePeriod struct { + orderedConfigByEpoch []config.EpochChangeGracePeriodByEpoch +} + +// NewEpochChangeGracePeriod creates a new instance of epochChangeGracePeriod +func NewEpochChangeGracePeriod( + gracePeriodByEpoch []config.EpochChangeGracePeriodByEpoch, +) (*epochChangeGracePeriod, error) { + if len(gracePeriodByEpoch) == 0 { + return nil, errEmptyGracePeriodByEpochConfig + } + // check for duplicated configs + seen := make(map[uint32]struct{}) + for _, cfg := range gracePeriodByEpoch { + _, exists := seen[cfg.EnableEpoch] + if exists { + return nil, errDuplicatedEpochConfig + } + seen[cfg.EnableEpoch] = struct{}{} + } + + // should have a config for epoch 0 + _, exists := seen[0] + if !exists { + return nil, errMissingEpochZeroConfig + } + + ecgp := &epochChangeGracePeriod{ + orderedConfigByEpoch: make([]config.EpochChangeGracePeriodByEpoch, len(gracePeriodByEpoch)), + } + + // sort the config values in ascending order + copy(ecgp.orderedConfigByEpoch, gracePeriodByEpoch) + sort.SliceStable(ecgp.orderedConfigByEpoch, func(i, j int) bool { + return ecgp.orderedConfigByEpoch[i].EnableEpoch < ecgp.orderedConfigByEpoch[j].EnableEpoch + }) + + return ecgp, nil +} + +// GetGracePeriodForEpoch returns the grace period for the given epoch +func (ecgp *epochChangeGracePeriod) GetGracePeriodForEpoch(epoch uint32) (uint32, error) { + for i := len(ecgp.orderedConfigByEpoch) - 1; i >= 0; i-- { + if ecgp.orderedConfigByEpoch[i].EnableEpoch <= epoch { + return ecgp.orderedConfigByEpoch[i].GracePeriodInRounds, nil + } + } + + return 0, errEmptyGracePeriodByEpochConfig +} + +// 
IsInterfaceNil checks if the instance is nil +func (ecgp *epochChangeGracePeriod) IsInterfaceNil() bool { + return ecgp == nil +} diff --git a/common/graceperiod/epochChange_test.go b/common/graceperiod/epochChange_test.go new file mode 100644 index 00000000000..0b69840ad83 --- /dev/null +++ b/common/graceperiod/epochChange_test.go @@ -0,0 +1,123 @@ +package graceperiod + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/config" +) + +func TestNewEpochChangeGracePeriod(t *testing.T) { + t.Parallel() + + t.Run("should return error for empty config", func(t *testing.T) { + t.Parallel() + + ecgp, err := NewEpochChangeGracePeriod(nil) + require.Nil(t, ecgp) + require.Equal(t, errEmptyGracePeriodByEpochConfig, err) + }) + + t.Run("should return error for duplicate epoch configs", func(t *testing.T) { + t.Parallel() + + configs := []config.EpochChangeGracePeriodByEpoch{ + {EnableEpoch: 0, GracePeriodInRounds: 10}, + {EnableEpoch: 1, GracePeriodInRounds: 20}, + {EnableEpoch: 1, GracePeriodInRounds: 30}, + } + ecgp, err := NewEpochChangeGracePeriod(configs) + require.Nil(t, ecgp) + require.Equal(t, errDuplicatedEpochConfig, err) + }) + + t.Run("should return error for missing epoch 0 config", func(t *testing.T) { + t.Parallel() + + configs := []config.EpochChangeGracePeriodByEpoch{ + {EnableEpoch: 1, GracePeriodInRounds: 20}, + {EnableEpoch: 2, GracePeriodInRounds: 30}, + } + ecgp, err := NewEpochChangeGracePeriod(configs) + require.Nil(t, ecgp) + require.Equal(t, errMissingEpochZeroConfig, err) + }) + + t.Run("should create epochChangeGracePeriod successfully", func(t *testing.T) { + t.Parallel() + + configs := []config.EpochChangeGracePeriodByEpoch{ + {EnableEpoch: 0, GracePeriodInRounds: 10}, + {EnableEpoch: 2, GracePeriodInRounds: 30}, + {EnableEpoch: 1, GracePeriodInRounds: 20}, + } + ecgp, err := NewEpochChangeGracePeriod(configs) + require.NotNil(t, ecgp) + require.NoError(t, err) + require.Equal(t, uint32(10), ecgp.orderedConfigByEpoch[0].GracePeriodInRounds) + require.Equal(t, uint32(20), ecgp.orderedConfigByEpoch[1].GracePeriodInRounds) + require.Equal(t, uint32(30), ecgp.orderedConfigByEpoch[2].GracePeriodInRounds) + }) +} + +func TestGetGracePeriodForEpoch(t *testing.T) { + t.Parallel() + + configs := []config.EpochChangeGracePeriodByEpoch{ + {EnableEpoch: 0, GracePeriodInRounds: 10}, + {EnableEpoch: 2, GracePeriodInRounds: 30}, + {EnableEpoch: 5, GracePeriodInRounds: 50}, + } + ecgp, err := NewEpochChangeGracePeriod(configs) + require.NotNil(t, ecgp) + require.NoError(t, err) + + t.Run("should return correct grace period for matching epoch", func(t *testing.T) { + t.Parallel() + + gracePeriod, err := ecgp.GetGracePeriodForEpoch(2) + require.NoError(t, err) + require.Equal(t, uint32(30), gracePeriod) + }) + + t.Run("should return grace period for closest lower epoch", func(t *testing.T) { + t.Parallel() + + gracePeriod, err := ecgp.GetGracePeriodForEpoch(4) + require.NoError(t, err) + require.Equal(t, uint32(30), gracePeriod) + }) + + t.Run("should return grace period for higher epochs than configured", func(t *testing.T) { + t.Parallel() + + gracePeriod, err := ecgp.GetGracePeriodForEpoch(10) + require.NoError(t, err) + require.Equal(t, uint32(50), gracePeriod) + }) + + t.Run("should return error for empty config", func(t *testing.T) { + t.Parallel() + + cfg := []config.EpochChangeGracePeriodByEpoch{ + {EnableEpoch: 0, GracePeriodInRounds: 10}, + } + emptyECGP, _ := NewEpochChangeGracePeriod(cfg) + 
// force the config to be empty, to simulate the error case + emptyECGP.orderedConfigByEpoch = make([]config.EpochChangeGracePeriodByEpoch, 0) + gracePeriod, err := emptyECGP.GetGracePeriodForEpoch(0) + require.Equal(t, uint32(0), gracePeriod) + require.Equal(t, errEmptyGracePeriodByEpochConfig, err) + }) +} + +func TestIsInterfaceNil(t *testing.T) { + t.Parallel() + + var ecgp *epochChangeGracePeriod + require.True(t, ecgp.IsInterfaceNil()) + + ecgp = &epochChangeGracePeriod{} + require.False(t, ecgp.IsInterfaceNil()) +} diff --git a/common/interface.go b/common/interface.go index 82b2960d0ce..c7f3059ff6f 100644 --- a/common/interface.go +++ b/common/interface.go @@ -8,6 +8,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" + + "github.com/multiversx/mx-chain-go/config" ) // TrieIteratorChannels defines the channels that are being used when iterating the trie nodes @@ -80,7 +82,10 @@ type StorageMarker interface { type KeyBuilder interface { BuildKey(keyPart []byte) GetKey() ([]byte, error) - Clone() KeyBuilder + GetRawKey() []byte + DeepClone() KeyBuilder + ShallowClone() KeyBuilder + Size() uint IsInterfaceNil() bool } @@ -222,17 +227,19 @@ type StateStatisticsHandler interface { Reset() ResetSnapshot() - IncrementCache() + IncrCache() Cache() uint64 - IncrementSnapshotCache() + IncrSnapshotCache() SnapshotCache() uint64 - IncrementPersister(epoch uint32) + IncrPersister(epoch uint32) Persister(epoch uint32) uint64 - IncrementSnapshotPersister(epoch uint32) + IncrWritePersister(epoch uint32) + WritePersister(epoch uint32) uint64 + IncrSnapshotPersister(epoch uint32) SnapshotPersister(epoch uint32) uint64 - IncrementTrie() + IncrTrie() Trie() uint64 ProcessingStats() []string @@ -372,3 +379,49 @@ type ExecutionOrderGetter interface { Len() int IsInterfaceNil() bool } + +// ChainParametersSubscriptionHandler defines the behavior of a chain parameters subscription handler +type ChainParametersSubscriptionHandler interface { + ChainParametersChanged(chainParameters config.ChainParametersByEpochConfig) + IsInterfaceNil() bool +} + +// HeadersPool defines what a headers pool structure can perform +type HeadersPool interface { + GetHeaderByHash(hash []byte) (data.HeaderHandler, error) +} + +// FieldsSizeChecker defines the behavior of a fields size checker common component +type FieldsSizeChecker interface { + IsProofSizeValid(proof data.HeaderProofHandler) bool + IsInterfaceNil() bool +} + +// EpochChangeGracePeriodHandler defines the behavior of a component that can return the grace period for a specific epoch +type EpochChangeGracePeriodHandler interface { + GetGracePeriodForEpoch(epoch uint32) (uint32, error) + IsInterfaceNil() bool +} + +// TrieNodeData is used to retrieve the data of a trie node +type TrieNodeData interface { + GetKeyBuilder() KeyBuilder + GetData() []byte + Size() uint64 + IsLeaf() bool + GetVersion() core.TrieNodeVersion +} + +// DfsIterator is used to iterate the trie nodes in a depth-first search manner +type DfsIterator interface { + GetLeaves(numLeaves int, maxSize uint64, leavesParser TrieLeafParser, ctx context.Context) (map[string]string, error) + GetIteratorState() [][]byte + IsInterfaceNil() bool +} + +// TrieLeavesRetriever is used to retrieve the leaves from the trie. If there is a saved checkpoint for the iterator id, +// it will continue to iterate from the checkpoint. 
+type TrieLeavesRetriever interface { + GetLeaves(numLeaves int, iteratorState [][]byte, leavesParser TrieLeafParser, ctx context.Context) (map[string]string, [][]byte, error) + IsInterfaceNil() bool +} diff --git a/common/proofs.go b/common/proofs.go new file mode 100644 index 00000000000..dc34a832f06 --- /dev/null +++ b/common/proofs.go @@ -0,0 +1,66 @@ +package common + +import ( + "fmt" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/consensus" +) + +// IsEpochStartProofForFlagActivation returns true if the provided proof is the proof of the epoch start block on the activation epoch of equivalent messages +func IsEpochStartProofForFlagActivation(proof consensus.ProofHandler, enableEpochsHandler EnableEpochsHandler) bool { + isStartOfEpochProof := proof.GetIsStartOfEpoch() + isProofInActivationEpoch := proof.GetHeaderEpoch() == enableEpochsHandler.GetActivationEpoch(AndromedaFlag) + + return isStartOfEpochProof && isProofInActivationEpoch +} + +// IsProofsFlagEnabledForHeader returns true if proofs flag has to be enabled for the provided header +func IsProofsFlagEnabledForHeader( + enableEpochsHandler EnableEpochsHandler, + header data.HeaderHandler, +) bool { + ifFlagActive := enableEpochsHandler.IsFlagEnabledInEpoch(AndromedaFlag, header.GetEpoch()) + isGenesisBlock := header.GetNonce() == 0 + + return ifFlagActive && !isGenesisBlock +} + +// VerifyProofAgainstHeader verifies the fields on the proof match the ones on the header +func VerifyProofAgainstHeader(proof data.HeaderProofHandler, header data.HeaderHandler) error { + if check.IfNil(proof) { + return ErrNilHeaderProof + } + if check.IfNil(header) { + return ErrNilHeaderHandler + } + + if proof.GetHeaderNonce() != header.GetNonce() { + return fmt.Errorf("%w, nonce mismatch", ErrInvalidHeaderProof) + } + if proof.GetHeaderShardId() != header.GetShardID() { + return fmt.Errorf("%w, shard id mismatch", ErrInvalidHeaderProof) + } + if proof.GetHeaderEpoch() != header.GetEpoch() { + return fmt.Errorf("%w, epoch mismatch", ErrInvalidHeaderProof) + } + if proof.GetHeaderRound() != header.GetRound() { + return fmt.Errorf("%w, round mismatch", ErrInvalidHeaderProof) + } + if proof.GetIsStartOfEpoch() != header.IsStartOfEpochBlock() { + return fmt.Errorf("%w, is start of epoch mismatch", ErrInvalidHeaderProof) + } + + return nil +} + +// GetEpochForConsensus will get epoch to be used by consensus based on equivalent proof data +func GetEpochForConsensus(proof data.HeaderProofHandler) uint32 { + epochForConsensus := proof.GetHeaderEpoch() + if proof.GetIsStartOfEpoch() && epochForConsensus > 0 { + epochForConsensus = epochForConsensus - 1 + } + + return epochForConsensus +} diff --git a/common/proofs_test.go b/common/proofs_test.go new file mode 100644 index 00000000000..5d3fb6b4f12 --- /dev/null +++ b/common/proofs_test.go @@ -0,0 +1,216 @@ +package common_test + +import ( + "errors" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/stretchr/testify/require" +) + +func TestIsEpochStartProofForFlagActivation(t *testing.T) { + t.Parallel() + + providedEpoch := uint32(123) + eeh := 
&enableEpochsHandlerMock.EnableEpochsHandlerStub{
+		GetActivationEpochCalled: func(flag core.EnableEpochFlag) uint32 {
+			require.Equal(t, common.AndromedaFlag, flag)
+			return providedEpoch
+		},
+	}
+
+	epochStartProofSameEpoch := &block.HeaderProof{
+		IsStartOfEpoch: true,
+		HeaderEpoch:    providedEpoch,
+	}
+	notEpochStartProofSameEpoch := &block.HeaderProof{
+		IsStartOfEpoch: false,
+		HeaderEpoch:    providedEpoch,
+	}
+	epochStartProofOtherEpoch := &block.HeaderProof{
+		IsStartOfEpoch: true,
+		HeaderEpoch:    providedEpoch + 1,
+	}
+	notEpochStartProofOtherEpoch := &block.HeaderProof{
+		IsStartOfEpoch: false,
+		HeaderEpoch:    providedEpoch + 1,
+	}
+
+	require.True(t, common.IsEpochStartProofForFlagActivation(epochStartProofSameEpoch, eeh))
+	require.False(t, common.IsEpochStartProofForFlagActivation(notEpochStartProofSameEpoch, eeh))
+	require.False(t, common.IsEpochStartProofForFlagActivation(epochStartProofOtherEpoch, eeh))
+	require.False(t, common.IsEpochStartProofForFlagActivation(notEpochStartProofOtherEpoch, eeh))
+}
+
+func TestVerifyProofAgainstHeader(t *testing.T) {
+	t.Parallel()
+
+	t.Run("nil proof or header", func(t *testing.T) {
+		t.Parallel()
+
+		proof := &block.HeaderProof{
+			PubKeysBitmap:       []byte("bitmap"),
+			AggregatedSignature: []byte("aggSig"),
+			HeaderHash:          []byte("hash"),
+			HeaderEpoch:         2,
+			HeaderNonce:         2,
+			HeaderShardId:       2,
+			HeaderRound:         2,
+			IsStartOfEpoch:      true,
+		}
+
+		header := &block.HeaderV2{
+			Header: &block.Header{
+				Nonce:              2,
+				ShardID:            2,
+				Round:              2,
+				Epoch:              2,
+				EpochStartMetaHash: []byte("epoch start meta hash"),
+			},
+		}
+
+		err := common.VerifyProofAgainstHeader(nil, header)
+		require.Equal(t, common.ErrNilHeaderProof, err)
+
+		err = common.VerifyProofAgainstHeader(proof, nil)
+		require.Equal(t, common.ErrNilHeaderHandler, err)
+	})
+
+	t.Run("nonce mismatch", func(t *testing.T) {
+		t.Parallel()
+
+		proof := &block.HeaderProof{
+			HeaderNonce: 2,
+		}
+
+		header := &block.HeaderV2{
+			Header: &block.Header{
+				Nonce: 3,
+			},
+		}
+
+		err := common.VerifyProofAgainstHeader(proof, header)
+		require.True(t, errors.Is(err, common.ErrInvalidHeaderProof))
+	})
+
+	t.Run("round mismatch", func(t *testing.T) {
+		t.Parallel()
+
+		proof := &block.HeaderProof{
+			HeaderRound: 2,
+		}
+
+		header := &block.HeaderV2{
+			Header: &block.Header{
+				Round: 3,
+			},
+		}
+
+		err := common.VerifyProofAgainstHeader(proof, header)
+		require.True(t, errors.Is(err, common.ErrInvalidHeaderProof))
+	})
+
+	t.Run("epoch mismatch", func(t *testing.T) {
+		t.Parallel()
+
+		proof := &block.HeaderProof{
+			HeaderEpoch: 2,
+		}
+
+		header := &block.HeaderV2{
+			Header: &block.Header{
+				Epoch: 3,
+			},
+		}
+
+		err := common.VerifyProofAgainstHeader(proof, header)
+		require.True(t, errors.Is(err, common.ErrInvalidHeaderProof))
+	})
+
+	t.Run("shard mismatch", func(t *testing.T) {
+		t.Parallel()
+
+		proof := &block.HeaderProof{
+			HeaderShardId: 2,
+		}
+
+		header := &block.HeaderV2{
+			Header: &block.Header{
+				ShardID: 3,
+			},
+		}
+
+		err := common.VerifyProofAgainstHeader(proof, header)
+		require.True(t, errors.Is(err, common.ErrInvalidHeaderProof))
+	})
+
+	t.Run("is start of epoch mismatch", func(t *testing.T) {
+		t.Parallel()
+
+		proof := &block.HeaderProof{
+			IsStartOfEpoch: false,
+		}
+
+		header := &block.HeaderV2{
+			Header: &block.Header{
+				EpochStartMetaHash: []byte("meta block hash"),
+			},
+		}
+
+		err := common.VerifyProofAgainstHeader(proof, header)
+		require.True(t, errors.Is(err, common.ErrInvalidHeaderProof))
+	})
+
+	t.Run("should work", func(t *testing.T) {
+		t.Parallel()
+
+		proof :=
&block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("aggSig"), + HeaderHash: []byte("hash"), + HeaderEpoch: 2, + HeaderNonce: 2, + HeaderShardId: 2, + HeaderRound: 2, + IsStartOfEpoch: true, + } + + header := &block.HeaderV2{ + Header: &block.Header{ + Nonce: 2, + ShardID: 2, + Round: 2, + Epoch: 2, + EpochStartMetaHash: []byte("epoch start meta hash"), + }, + } + + err := common.VerifyProofAgainstHeader(proof, header) + require.Nil(t, err) + + }) +} + +func TestGetEpochForConsensus(t *testing.T) { + t.Parallel() + + providedEpoch := uint32(10) + proof := &block.HeaderProof{ + HeaderEpoch: providedEpoch, + IsStartOfEpoch: false, + } + + epoch := common.GetEpochForConsensus(proof) + require.Equal(t, providedEpoch, epoch) + + proof = &block.HeaderProof{ + HeaderEpoch: providedEpoch, + IsStartOfEpoch: true, + } + + epoch = common.GetEpochForConsensus(proof) + require.Equal(t, providedEpoch-1, epoch) +} diff --git a/common/statistics/disabled/stateStatistics.go b/common/statistics/disabled/stateStatistics.go index c3bdf12420d..acc90a4fe91 100644 --- a/common/statistics/disabled/stateStatistics.go +++ b/common/statistics/disabled/stateStatistics.go @@ -19,8 +19,8 @@ func (s *stateStatistics) Reset() { func (s *stateStatistics) ResetSnapshot() { } -// IncrementCache does nothing -func (s *stateStatistics) IncrementCache() { +// IncrCache does nothing +func (s *stateStatistics) IncrCache() { } // Cache returns zero @@ -28,17 +28,17 @@ func (s *stateStatistics) Cache() uint64 { return 0 } -// IncrementSnapshotCache does nothing -func (ss *stateStatistics) IncrementSnapshotCache() { +// IncrSnapshotCache does nothing +func (s *stateStatistics) IncrSnapshotCache() { } // SnapshotCache returns the number of cached operations -func (ss *stateStatistics) SnapshotCache() uint64 { +func (s *stateStatistics) SnapshotCache() uint64 { return 0 } -// IncrementPersister does nothing -func (s *stateStatistics) IncrementPersister(epoch uint32) { +// IncrPersister does nothing +func (s *stateStatistics) IncrPersister(epoch uint32) { } // Persister returns zero @@ -46,17 +46,26 @@ func (s *stateStatistics) Persister(epoch uint32) uint64 { return 0 } -// IncrementSnapshotPersister does nothing -func (ss *stateStatistics) IncrementSnapshotPersister(epoch uint32) { +// IncrWritePersister does nothing +func (s *stateStatistics) IncrWritePersister(epoch uint32) { +} + +// WritePersister returns zero +func (s *stateStatistics) WritePersister(epoch uint32) uint64 { + return 0 +} + +// IncrSnapshotPersister does nothing +func (s *stateStatistics) IncrSnapshotPersister(epoch uint32) { } // SnapshotPersister returns the number of persister operations -func (ss *stateStatistics) SnapshotPersister(epoch uint32) uint64 { +func (s *stateStatistics) SnapshotPersister(epoch uint32) uint64 { return 0 } -// IncrementTrie does nothing -func (s *stateStatistics) IncrementTrie() { +// IncrTrie does nothing +func (s *stateStatistics) IncrTrie() { } // Trie returns zero diff --git a/common/statistics/disabled/stateStatistics_test.go b/common/statistics/disabled/stateStatistics_test.go index 725ec3ee6a1..7d5ff874204 100644 --- a/common/statistics/disabled/stateStatistics_test.go +++ b/common/statistics/disabled/stateStatistics_test.go @@ -31,16 +31,18 @@ func TestStateStatistics_MethodsShouldNotPanic(t *testing.T) { stats.ResetSnapshot() stats.ResetAll() - stats.IncrementCache() - stats.IncrementSnapshotCache() - stats.IncrementSnapshotCache() - stats.IncrementPersister(1) - 
stats.IncrementSnapshotPersister(1)
-	stats.IncrementTrie()
+	stats.IncrCache()
+	stats.IncrSnapshotCache()
+	stats.IncrSnapshotCache()
+	stats.IncrPersister(1)
+	stats.IncrWritePersister(1)
+	stats.IncrSnapshotPersister(1)
+	stats.IncrTrie()
 
 	require.Equal(t, uint64(0), stats.Cache())
 	require.Equal(t, uint64(0), stats.SnapshotCache())
 	require.Equal(t, uint64(0), stats.Persister(1))
+	require.Equal(t, uint64(0), stats.WritePersister(1))
 	require.Equal(t, uint64(0), stats.SnapshotPersister(1))
 	require.Equal(t, uint64(0), stats.Trie())
 }
diff --git a/common/statistics/stateStatistics.go b/common/statistics/stateStatistics.go
index 474dc6d47d1..5804d3dff78 100644
--- a/common/statistics/stateStatistics.go
+++ b/common/statistics/stateStatistics.go
@@ -11,6 +11,7 @@ type stateStatistics struct {
 	numSnapshotCache uint64
 
 	numPersister         map[uint32]uint64
+	numWritePersister    map[uint32]uint64
 	numSnapshotPersister map[uint32]uint64
 	mutPersisters        sync.RWMutex
 
@@ -21,22 +22,18 @@ type stateStatistics struct {
 func NewStateStatistics() *stateStatistics {
 	return &stateStatistics{
 		numPersister:         make(map[uint32]uint64),
+		numWritePersister:    make(map[uint32]uint64),
 		numSnapshotPersister: make(map[uint32]uint64),
 	}
 }
 
-// ResetAll will reset all statistics
-func (ss *stateStatistics) ResetAll() {
-	ss.Reset()
-	ss.ResetSnapshot()
-}
-
 // Reset will reset processing statistics
 func (ss *stateStatistics) Reset() {
 	atomic.StoreUint64(&ss.numCache, 0)
 
 	ss.mutPersisters.Lock()
 	ss.numPersister = make(map[uint32]uint64)
+	ss.numWritePersister = make(map[uint32]uint64)
 	ss.mutPersisters.Unlock()
 
 	atomic.StoreUint64(&ss.numTrie, 0)
@@ -51,8 +48,8 @@ func (ss *stateStatistics) ResetSnapshot() {
 	ss.mutPersisters.Unlock()
 }
 
-// IncrementCache will increment cache counter
-func (ss *stateStatistics) IncrementCache() {
+// IncrCache will increment cache counter
+func (ss *stateStatistics) IncrCache() {
 	atomic.AddUint64(&ss.numCache, 1)
 }
 
@@ -61,8 +58,8 @@ func (ss *stateStatistics) Cache() uint64 {
 	return atomic.LoadUint64(&ss.numCache)
 }
 
-// IncrementSnapshotCache will increment snapshot cache counter
-func (ss *stateStatistics) IncrementSnapshotCache() {
+// IncrSnapshotCache will increment snapshot cache counter
+func (ss *stateStatistics) IncrSnapshotCache() {
 	atomic.AddUint64(&ss.numSnapshotCache, 1)
 }
 
@@ -71,8 +68,8 @@ func (ss *stateStatistics) SnapshotCache() uint64 {
 	return atomic.LoadUint64(&ss.numSnapshotCache)
 }
 
-// IncrementPersister will increment persister counter
-func (ss *stateStatistics) IncrementPersister(epoch uint32) {
+// IncrPersister will increment persister counter
+func (ss *stateStatistics) IncrPersister(epoch uint32) {
 	ss.mutPersisters.Lock()
 	defer ss.mutPersisters.Unlock()
 
@@ -87,8 +84,24 @@ func (ss *stateStatistics) Persister(epoch uint32) uint64 {
 	return ss.numPersister[epoch]
 }
 
-// IncrementSnapshotPersister will increment snapshot persister counter
-func (ss *stateStatistics) IncrementSnapshotPersister(epoch uint32) {
+// IncrWritePersister will increment persister write counter
+func (ss *stateStatistics) IncrWritePersister(epoch uint32) {
+	ss.mutPersisters.Lock()
+	defer ss.mutPersisters.Unlock()
+
+	ss.numWritePersister[epoch]++
+}
+
+// WritePersister returns the number of write persister operations
+func (ss *stateStatistics) WritePersister(epoch uint32) uint64 {
+	ss.mutPersisters.RLock()
+	defer ss.mutPersisters.RUnlock()
+
+	return ss.numWritePersister[epoch]
+}
+
+// IncrSnapshotPersister will increment snapshot persister counter
+func (ss *stateStatistics) IncrSnapshotPersister(epoch
uint32) { ss.mutPersisters.Lock() defer ss.mutPersisters.Unlock() @@ -103,8 +116,8 @@ func (ss *stateStatistics) SnapshotPersister(epoch uint32) uint64 { return ss.numSnapshotPersister[epoch] } -// IncrementTrie will increment trie counter -func (ss *stateStatistics) IncrementTrie() { +// IncrTrie will increment trie counter +func (ss *stateStatistics) IncrTrie() { atomic.AddUint64(&ss.numTrie, 1) } @@ -142,6 +155,10 @@ func (ss *stateStatistics) ProcessingStats() []string { stats = append(stats, fmt.Sprintf("persister epoch = %v op = %v", epoch, counter)) } + for epoch, counter := range ss.numWritePersister { + stats = append(stats, fmt.Sprintf("write persister epoch = %v op = %v", epoch, counter)) + } + stats = append(stats, fmt.Sprintf("trie op = %v", atomic.LoadUint64(&ss.numTrie))) return stats diff --git a/common/statistics/stateStatistics_test.go b/common/statistics/stateStatistics_test.go index 674b3d8ea6b..1599dd3bec3 100644 --- a/common/statistics/stateStatistics_test.go +++ b/common/statistics/stateStatistics_test.go @@ -27,11 +27,11 @@ func TestStateStatistics_Processing(t *testing.T) { assert.Equal(t, uint64(0), ss.Trie()) - ss.IncrementTrie() - ss.IncrementTrie() + ss.IncrTrie() + ss.IncrTrie() assert.Equal(t, uint64(2), ss.Trie()) - ss.IncrementTrie() + ss.IncrTrie() assert.Equal(t, uint64(3), ss.Trie()) ss.Reset() @@ -47,15 +47,21 @@ func TestStateStatistics_Processing(t *testing.T) { assert.Equal(t, uint64(0), ss.Persister(epoch)) - ss.IncrementPersister(epoch) - ss.IncrementPersister(epoch) + ss.IncrPersister(epoch) + ss.IncrPersister(epoch) assert.Equal(t, uint64(2), ss.Persister(epoch)) - ss.IncrementPersister(epoch) + ss.IncrPersister(epoch) assert.Equal(t, uint64(3), ss.Persister(epoch)) + ss.IncrWritePersister(epoch) + ss.IncrWritePersister(epoch) + ss.IncrWritePersister(epoch) + assert.Equal(t, uint64(3), ss.WritePersister(epoch)) + ss.Reset() assert.Equal(t, uint64(0), ss.Persister(epoch)) + assert.Equal(t, uint64(0), ss.WritePersister(epoch)) }) t.Run("cache operations", func(t *testing.T) { @@ -65,11 +71,11 @@ func TestStateStatistics_Processing(t *testing.T) { assert.Equal(t, uint64(0), ss.Cache()) - ss.IncrementCache() - ss.IncrementCache() + ss.IncrCache() + ss.IncrCache() assert.Equal(t, uint64(2), ss.Cache()) - ss.IncrementCache() + ss.IncrCache() assert.Equal(t, uint64(3), ss.Cache()) ss.Reset() @@ -89,11 +95,11 @@ func TestStateStatistics_Snapshot(t *testing.T) { assert.Equal(t, uint64(0), ss.SnapshotPersister(epoch)) - ss.IncrementSnapshotPersister(epoch) - ss.IncrementSnapshotPersister(epoch) + ss.IncrSnapshotPersister(epoch) + ss.IncrSnapshotPersister(epoch) assert.Equal(t, uint64(2), ss.SnapshotPersister(epoch)) - ss.IncrementSnapshotPersister(epoch) + ss.IncrSnapshotPersister(epoch) assert.Equal(t, uint64(3), ss.SnapshotPersister(epoch)) ss.ResetSnapshot() @@ -107,11 +113,11 @@ func TestStateStatistics_Snapshot(t *testing.T) { assert.Equal(t, uint64(0), ss.Cache()) - ss.IncrementSnapshotCache() - ss.IncrementSnapshotCache() + ss.IncrSnapshotCache() + ss.IncrSnapshotCache() assert.Equal(t, uint64(2), ss.SnapshotCache()) - ss.IncrementSnapshotCache() + ss.IncrSnapshotCache() assert.Equal(t, uint64(3), ss.SnapshotCache()) ss.ResetSnapshot() @@ -140,15 +146,15 @@ func TestStateStatistics_ConcurrenyOperations(t *testing.T) { for i := 0; i < numIterations; i++ { go func(idx int) { - switch idx % 11 { + switch idx % 13 { case 0: ss.Reset() case 1: - ss.IncrementCache() + ss.IncrCache() case 2: - ss.IncrementPersister(epoch) + ss.IncrPersister(epoch) case 3: 
- ss.IncrementTrie() + ss.IncrTrie() case 7: _ = ss.Cache() case 8: @@ -157,6 +163,10 @@ func TestStateStatistics_ConcurrenyOperations(t *testing.T) { _ = ss.Trie() case 10: _ = ss.ProcessingStats() + case 11: + ss.IncrWritePersister(epoch) + case 12: + _ = ss.WritePersister(epoch) } wg.Done() diff --git a/config/config.go b/config/config.go index 49ef257c341..cd3ee46781e 100644 --- a/config/config.go +++ b/config/config.go @@ -19,6 +19,12 @@ type HeadersPoolConfig struct { NumElementsToRemoveOnEviction int } +// ProofsPoolConfig will map the proofs cache configuration +type ProofsPoolConfig struct { + CleanupNonceDelta uint64 + BucketSize int +} + // DBConfig will map the database configuration type DBConfig struct { FilePath string @@ -160,13 +166,15 @@ type Config struct { BootstrapStorage StorageConfig MetaBlockStorage StorageConfig + ProofsStorage StorageConfig - AccountsTrieStorage StorageConfig - PeerAccountsTrieStorage StorageConfig - EvictionWaitingList EvictionWaitingListConfig - StateTriesConfig StateTriesConfig - TrieStorageManagerConfig TrieStorageManagerConfig - BadBlocksCache CacheConfig + AccountsTrieStorage StorageConfig + PeerAccountsTrieStorage StorageConfig + EvictionWaitingList EvictionWaitingListConfig + StateTriesConfig StateTriesConfig + TrieStorageManagerConfig TrieStorageManagerConfig + TrieLeavesRetrieverConfig TrieLeavesRetrieverConfig + BadBlocksCache CacheConfig TxBlockBodyDataPool CacheConfig PeerBlockBodyDataPool CacheConfig @@ -208,6 +216,7 @@ type Config struct { NTPConfig NTPConfig HeadersPoolConfig HeadersPoolConfig + ProofsPoolConfig ProofsPoolConfig BlockSizeThrottleConfig BlockSizeThrottleConfig VirtualMachine VirtualMachineServicesConfig BuiltInFunctions BuiltInFunctionsConfig @@ -228,6 +237,8 @@ type Config struct { PeersRatingConfig PeersRatingConfig PoolsCleanersConfig PoolsCleanersConfig Redundancy RedundancyConfig + + InterceptedDataVerifier InterceptedDataVerifierConfig } // PeersRatingConfig will hold settings related to peers rating @@ -278,6 +289,12 @@ type MultiSignerConfig struct { Type string } +// EpochChangeGracePeriodByEpoch defines a config tuple for the epoch change grace period +type EpochChangeGracePeriodByEpoch struct { + EnableEpoch uint32 + GracePeriodInRounds uint32 +} + // GeneralSettingsConfig will hold the general settings for a node type GeneralSettingsConfig struct { StatusPollingIntervalSec int @@ -290,6 +307,8 @@ type GeneralSettingsConfig struct { GenesisMaxNumberOfShards uint32 SyncProcessTimeInMillis uint32 SetGuardianEpochsDelay uint32 + ChainParametersByEpoch []ChainParametersByEpochConfig + EpochChangeGracePeriodByEpoch []EpochChangeGracePeriodByEpoch } // HardwareRequirementsConfig will hold the hardware requirements config @@ -591,6 +610,20 @@ type Configs struct { ConfigurationPathsHolder *ConfigurationPathsHolder EpochConfig *EpochConfig RoundConfig *RoundConfig + NodesConfig *NodesConfig +} + +// NodesConfig is the data transfer object used to map the nodes' configuration in regard to the genesis nodes setup +type NodesConfig struct { + StartTime int64 `json:"startTime"` + InitialNodes []*InitialNodeConfig `json:"initialNodes"` +} + +// InitialNodeConfig holds data about a genesis node +type InitialNodeConfig struct { + PubKey string `json:"pubkey"` + Address string `json:"address"` + InitialRating uint32 `json:"initialRating"` } // ConfigurationPathsHolder holds all configuration filenames and configuration paths used to start the node @@ -640,3 +673,33 @@ type PoolsCleanersConfig struct { type 
RedundancyConfig struct { MaxRoundsOfInactivityAccepted int } + +// ChainParametersByEpochConfig holds chain parameters that are configurable based on epochs +type ChainParametersByEpochConfig struct { + RoundDuration uint64 + Hysteresis float32 + EnableEpoch uint32 + ShardConsensusGroupSize uint32 + ShardMinNumNodes uint32 + MetachainConsensusGroupSize uint32 + MetachainMinNumNodes uint32 + Adaptivity bool +} + +// IndexBroadcastDelay holds a pair of starting consensus index and the delay the nodes should wait before broadcasting final info +type IndexBroadcastDelay struct { + EndIndex int + DelayInMilliseconds uint64 +} + +// InterceptedDataVerifierConfig holds the configuration for the intercepted data verifier +type InterceptedDataVerifierConfig struct { + CacheSpanInSec uint64 + CacheExpiryInSec uint64 +} + +// TrieLeavesRetrieverConfig represents the config options to be used when setting up the trie leaves retriever +type TrieLeavesRetrieverConfig struct { + Enabled bool + MaxSizeInBytes uint64 +} diff --git a/config/epochConfig.go b/config/epochConfig.go index 491f111558b..94a022c866b 100644 --- a/config/epochConfig.go +++ b/config/epochConfig.go @@ -126,6 +126,8 @@ type EnableEpochs struct { FixRelayedMoveBalanceToNonPayableSCEnableEpoch uint32 RelayedTransactionsV3EnableEpoch uint32 RelayedTransactionsV3FixESDTTransferEnableEpoch uint32 + AndromedaEnableEpoch uint32 + CheckBuiltInCallOnTransferValueAndFailEnableRound uint32 MaskVMInternalDependenciesErrorsEnableEpoch uint32 FixBackTransferOPCODEEnableEpoch uint32 ValidationOnGobDecodeEnableEpoch uint32 diff --git a/config/ratingsConfig.go b/config/ratingsConfig.go index a4c243cd51b..c20ca1aa93b 100644 --- a/config/ratingsConfig.go +++ b/config/ratingsConfig.go @@ -19,12 +19,12 @@ type General struct { // ShardChain will hold RatingSteps for the Shard type ShardChain struct { - RatingSteps + RatingStepsByEpoch []RatingSteps } // MetaChain will hold RatingSteps for the Meta type MetaChain struct { - RatingSteps + RatingStepsByEpoch []RatingSteps } // RatingValue will hold different rating options with increase and decrease steps @@ -46,6 +46,7 @@ type RatingSteps struct { ProposerDecreaseFactor float32 ValidatorDecreaseFactor float32 ConsecutiveMissedBlocksPenalty float32 + EnableEpoch uint32 } // PeerHonestyConfig holds the parameters for the peer honesty handler diff --git a/config/tomlConfig_test.go b/config/tomlConfig_test.go index efd6633b43f..c284edbd0ae 100644 --- a/config/tomlConfig_test.go +++ b/config/tomlConfig_test.go @@ -48,6 +48,20 @@ func TestTomlParser(t *testing.T) { } cfgExpected := Config{ + GeneralSettings: GeneralSettingsConfig{ + ChainParametersByEpoch: []ChainParametersByEpochConfig{ + { + EnableEpoch: 0, + RoundDuration: 4000, + ShardMinNumNodes: 4, + ShardConsensusGroupSize: 3, + MetachainMinNumNodes: 6, + MetachainConsensusGroupSize: 5, + Hysteresis: 0.0, + Adaptivity: false, + }, + }, + }, MiniBlocksStorage: StorageConfig{ Cache: CacheConfig{ Capacity: uint32(txBlockBodyStorageSize), @@ -150,6 +164,10 @@ func TestTomlParser(t *testing.T) { }, } testString := ` +[GeneralSettings] + ChainParametersByEpoch = [ + { EnableEpoch = 0, RoundDuration = 4000, ShardConsensusGroupSize = 3, ShardMinNumNodes = 4, MetachainConsensusGroupSize = 5, MetachainMinNumNodes = 6, Hysteresis = 0.0, Adaptivity = false } + ] [MiniBlocksStorage] [MiniBlocksStorage.Cache] Capacity = ` + strconv.Itoa(txBlockBodyStorageSize) + ` @@ -899,20 +917,26 @@ func TestEnableEpochConfig(t *testing.T) { # 
RelayedTransactionsV3FixESDTTransferEnableEpoch represents the epoch when the fix for relayed transactions v3 with esdt transfer will be enabled RelayedTransactionsV3FixESDTTransferEnableEpoch = 104 + # AndromedaEnableEpoch represents the epoch when the equivalent messages are enabled + AndromedaEnableEpoch = 105 + + # CheckBuiltInCallOnTransferValueAndFailEnableRound represents the ROUND when the check on transfer value fix is activated + CheckBuiltInCallOnTransferValueAndFailEnableRound = 106 + # MaskVMInternalDependenciesErrorsEnableEpoch represents the epoch when the additional internal erorr masking in vm is enabled - MaskVMInternalDependenciesErrorsEnableEpoch = 105 + MaskVMInternalDependenciesErrorsEnableEpoch = 107 # FixBackTransferOPCODEEnableEpoch represents the epoch when the fix for back transfers opcode will be enabled - FixBackTransferOPCODEEnableEpoch = 106 + FixBackTransferOPCODEEnableEpoch = 108 # ValidationOnGobDecodeEnableEpoch represents the epoch when validation on GobDecode will be taken into account - ValidationOnGobDecodeEnableEpoch = 107 + ValidationOnGobDecodeEnableEpoch = 109 # BarnardOpcodesEnableEpoch represents the epoch when Barnard opcodes will be enabled - BarnardOpcodesEnableEpoch = 108 + BarnardOpcodesEnableEpoch = 110 # AutomaticActivationOfNodesDisableEpoch represents the epoch when automatic activation of nodes for validators is disabled - AutomaticActivationOfNodesDisableEpoch = 104 + AutomaticActivationOfNodesDisableEpoch = 111 # MaxNodesChangeEnableEpoch holds configuration for changing the maximum number of nodes and the enabling epoch MaxNodesChangeEnableEpoch = [ @@ -1038,11 +1062,13 @@ func TestEnableEpochConfig(t *testing.T) { FixRelayedMoveBalanceToNonPayableSCEnableEpoch: 102, RelayedTransactionsV3EnableEpoch: 103, RelayedTransactionsV3FixESDTTransferEnableEpoch: 104, - MaskVMInternalDependenciesErrorsEnableEpoch: 105, - FixBackTransferOPCODEEnableEpoch: 106, - ValidationOnGobDecodeEnableEpoch: 107, - BarnardOpcodesEnableEpoch: 108, - AutomaticActivationOfNodesDisableEpoch: 104, + AndromedaEnableEpoch: 105, + CheckBuiltInCallOnTransferValueAndFailEnableRound: 106, + MaskVMInternalDependenciesErrorsEnableEpoch: 107, + FixBackTransferOPCODEEnableEpoch: 108, + ValidationOnGobDecodeEnableEpoch: 109, + BarnardOpcodesEnableEpoch: 110, + AutomaticActivationOfNodesDisableEpoch: 111, MaxNodesChangeEnableEpoch: []MaxNodesChangeConfig{ { EpochEnable: 44, diff --git a/consensus/broadcast/commonMessenger.go b/consensus/broadcast/commonMessenger.go index 60c59e01145..e28c39defec 100644 --- a/consensus/broadcast/commonMessenger.go +++ b/consensus/broadcast/commonMessenger.go @@ -11,37 +11,25 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/sharding" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("consensus/broadcast") -// delayedBroadcaster exposes functionality for handling the consensus members broadcasting 
of delay data -type delayedBroadcaster interface { - SetLeaderData(data *delayedBroadcastData) error - SetValidatorData(data *delayedBroadcastData) error - SetHeaderForValidator(vData *validatorHeaderBroadcastData) error - SetBroadcastHandlers( - mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, - txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, - headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, - ) error - Close() -} - type commonMessenger struct { marshalizer marshal.Marshalizer hasher hashing.Hasher messenger consensus.P2PMessenger shardCoordinator sharding.Coordinator peerSignatureHandler crypto.PeerSignatureHandler - delayedBlockBroadcaster delayedBroadcaster + delayedBlockBroadcaster DelayedBroadcaster keysHandler consensus.KeysHandler } @@ -58,6 +46,7 @@ type CommonMessengerArgs struct { MaxValidatorDelayCacheSize uint32 AlarmScheduler core.TimersScheduler KeysHandler consensus.KeysHandler + DelayedBroadcaster DelayedBroadcaster } func checkCommonMessengerNilParameters( @@ -93,6 +82,9 @@ func checkCommonMessengerNilParameters( if check.IfNil(args.KeysHandler) { return ErrNilKeysHandler } + if check.IfNil(args.DelayedBroadcaster) { + return ErrNilDelayedBroadcaster + } return nil } @@ -241,3 +233,18 @@ func (cm *commonMessenger) broadcast(topic string, data []byte, pkBytes []byte) cm.messenger.BroadcastUsingPrivateKey(topic, data, pid, skBytes) } + +func (cm *commonMessenger) broadcastEquivalentProof(proof data.HeaderProofHandler, pkBytes []byte, topic string) error { + if check.IfNil(proof) { + return spos.ErrNilHeaderProof + } + + msgProof, err := cm.marshalizer.Marshal(proof) + if err != nil { + return err + } + + cm.broadcast(topic, msgProof, pkBytes) + + return nil +} diff --git a/consensus/broadcast/delayedBroadcast.go b/consensus/broadcast/delayedBroadcast.go index 511ac6d79e6..9f67dcbc248 100644 --- a/consensus/broadcast/delayedBroadcast.go +++ b/consensus/broadcast/delayedBroadcast.go @@ -11,8 +11,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" @@ -35,25 +37,6 @@ type ArgsDelayedBlockBroadcaster struct { AlarmScheduler timersScheduler } -type validatorHeaderBroadcastData struct { - headerHash []byte - header data.HeaderHandler - metaMiniBlocksData map[uint32][]byte - metaTransactionsData map[string][][]byte - order uint32 - pkBytes []byte -} - -type delayedBroadcastData struct { - headerHash []byte - header data.HeaderHandler - miniBlocksData map[uint32][]byte - miniBlockHashes map[string]map[string]struct{} - transactions map[string][][]byte - order uint32 - pkBytes []byte -} - // timersScheduler exposes functionality for scheduling multiple timers type timersScheduler interface { Add(callback func(alarmID string), duration time.Duration, alarmID string) @@ -72,15 +55,16 @@ type delayedBlockBroadcaster struct { interceptorsContainer process.InterceptorsContainer shardCoordinator sharding.Coordinator headersSubscriber consensus.HeadersPoolSubscriber - 
valHeaderBroadcastData []*validatorHeaderBroadcastData - valBroadcastData []*delayedBroadcastData - delayedBroadcastData []*delayedBroadcastData + valHeaderBroadcastData []*shared.ValidatorHeaderBroadcastData + valBroadcastData []*shared.DelayedBroadcastData + delayedBroadcastData []*shared.DelayedBroadcastData maxDelayCacheSize uint32 maxValidatorDelayCacheSize uint32 mutDataForBroadcast sync.RWMutex broadcastMiniblocksData func(mbData map[uint32][]byte, pkBytes []byte) error broadcastTxsData func(txData map[string][][]byte, pkBytes []byte) error broadcastHeader func(header data.HeaderHandler, pkBytes []byte) error + broadcastConsensusMessage func(message *consensus.Message) error cacheHeaders storage.Cacher mutHeadersCache sync.RWMutex } @@ -110,9 +94,9 @@ func NewDelayedBlockBroadcaster(args *ArgsDelayedBlockBroadcaster) (*delayedBloc shardCoordinator: args.ShardCoordinator, interceptorsContainer: args.InterceptorsContainer, headersSubscriber: args.HeadersSubscriber, - valHeaderBroadcastData: make([]*validatorHeaderBroadcastData, 0), - valBroadcastData: make([]*delayedBroadcastData, 0), - delayedBroadcastData: make([]*delayedBroadcastData, 0), + valHeaderBroadcastData: make([]*shared.ValidatorHeaderBroadcastData, 0), + valBroadcastData: make([]*shared.DelayedBroadcastData, 0), + delayedBroadcastData: make([]*shared.DelayedBroadcastData, 0), maxDelayCacheSize: args.LeaderCacheSize, maxValidatorDelayCacheSize: args.ValidatorCacheSize, mutDataForBroadcast: sync.RWMutex{}, @@ -135,22 +119,22 @@ func NewDelayedBlockBroadcaster(args *ArgsDelayedBlockBroadcaster) (*delayedBloc } // SetLeaderData sets the data for consensus leader delayed broadcast -func (dbb *delayedBlockBroadcaster) SetLeaderData(broadcastData *delayedBroadcastData) error { +func (dbb *delayedBlockBroadcaster) SetLeaderData(broadcastData *shared.DelayedBroadcastData) error { if broadcastData == nil { return spos.ErrNilParameter } log.Trace("delayedBlockBroadcaster.SetLeaderData: setting leader delay data", - "headerHash", broadcastData.headerHash, + "headerHash", broadcastData.HeaderHash, ) - dataToBroadcast := make([]*delayedBroadcastData, 0) + dataToBroadcast := make([]*shared.DelayedBroadcastData, 0) dbb.mutDataForBroadcast.Lock() dbb.delayedBroadcastData = append(dbb.delayedBroadcastData, broadcastData) if len(dbb.delayedBroadcastData) > int(dbb.maxDelayCacheSize) { log.Debug("delayedBlockBroadcaster.SetLeaderData: leader broadcasts old data before alarm due to too much delay data", - "headerHash", dbb.delayedBroadcastData[0].headerHash, + "headerHash", dbb.delayedBroadcastData[0].HeaderHash, "nbDelayedData", len(dbb.delayedBroadcastData), "maxDelayCacheSize", dbb.maxDelayCacheSize, ) @@ -167,14 +151,17 @@ func (dbb *delayedBlockBroadcaster) SetLeaderData(broadcastData *delayedBroadcas } // SetHeaderForValidator sets the header to be broadcast by validator if leader fails to broadcast it -func (dbb *delayedBlockBroadcaster) SetHeaderForValidator(vData *validatorHeaderBroadcastData) error { - if check.IfNil(vData.header) { +func (dbb *delayedBlockBroadcaster) SetHeaderForValidator(vData *shared.ValidatorHeaderBroadcastData) error { + if check.IfNil(vData.Header) { return spos.ErrNilHeader } - if len(vData.headerHash) == 0 { + if len(vData.HeaderHash) == 0 { return spos.ErrNilHeaderHash } + dbb.mutDataForBroadcast.Lock() + defer dbb.mutDataForBroadcast.Unlock() + log.Trace("delayedBlockBroadcaster.SetHeaderForValidator", "nbDelayedBroadcastData", len(dbb.delayedBroadcastData), "nbValBroadcastData", 
len(dbb.valBroadcastData), @@ -182,25 +169,25 @@ func (dbb *delayedBlockBroadcaster) SetHeaderForValidator(vData *validatorHeader ) // set alarm only for validators that are aware that the block was finalized - if len(vData.header.GetSignature()) != 0 { - _, alreadyReceived := dbb.cacheHeaders.Get(vData.headerHash) + if len(vData.Header.GetSignature()) != 0 { + _, alreadyReceived := dbb.cacheHeaders.Get(vData.HeaderHash) if alreadyReceived { return nil } - duration := validatorDelayPerOrder * time.Duration(vData.order) + duration := validatorDelayPerOrder * time.Duration(vData.Order) dbb.valHeaderBroadcastData = append(dbb.valHeaderBroadcastData, vData) - alarmID := prefixHeaderAlarm + hex.EncodeToString(vData.headerHash) + alarmID := prefixHeaderAlarm + hex.EncodeToString(vData.HeaderHash) dbb.alarm.Add(dbb.headerAlarmExpired, duration, alarmID) log.Trace("delayedBlockBroadcaster.SetHeaderForValidator: header alarm has been set", - "validatorConsensusOrder", vData.order, - "headerHash", vData.headerHash, + "validatorConsensusOrder", vData.Order, + "headerHash", vData.HeaderHash, "alarmID", alarmID, "duration", duration, ) } else { log.Trace("delayedBlockBroadcaster.SetHeaderForValidator: header alarm has not been set", - "validatorConsensusOrder", vData.order, + "validatorConsensusOrder", vData.Order, ) } @@ -208,29 +195,29 @@ func (dbb *delayedBlockBroadcaster) SetHeaderForValidator(vData *validatorHeader } // SetValidatorData sets the data for consensus validator delayed broadcast -func (dbb *delayedBlockBroadcaster) SetValidatorData(broadcastData *delayedBroadcastData) error { +func (dbb *delayedBlockBroadcaster) SetValidatorData(broadcastData *shared.DelayedBroadcastData) error { if broadcastData == nil { return spos.ErrNilParameter } alarmIDsToCancel := make([]string, 0) log.Trace("delayedBlockBroadcaster.SetValidatorData: setting validator delay data", - "headerHash", broadcastData.headerHash, - "round", broadcastData.header.GetRound(), - "prevRandSeed", broadcastData.header.GetPrevRandSeed(), + "headerHash", broadcastData.HeaderHash, + "round", broadcastData.Header.GetRound(), + "prevRandSeed", broadcastData.Header.GetPrevRandSeed(), ) dbb.mutDataForBroadcast.Lock() - broadcastData.miniBlockHashes = dbb.extractMiniBlockHashesCrossFromMe(broadcastData.header) + broadcastData.MiniBlockHashes = dbb.extractMiniBlockHashesCrossFromMe(broadcastData.Header) dbb.valBroadcastData = append(dbb.valBroadcastData, broadcastData) if len(dbb.valBroadcastData) > int(dbb.maxValidatorDelayCacheSize) { - alarmHeaderID := prefixHeaderAlarm + hex.EncodeToString(dbb.valBroadcastData[0].headerHash) - alarmDelayID := prefixDelayDataAlarm + hex.EncodeToString(dbb.valBroadcastData[0].headerHash) + alarmHeaderID := prefixHeaderAlarm + hex.EncodeToString(dbb.valBroadcastData[0].HeaderHash) + alarmDelayID := prefixDelayDataAlarm + hex.EncodeToString(dbb.valBroadcastData[0].HeaderHash) alarmIDsToCancel = append(alarmIDsToCancel, alarmHeaderID, alarmDelayID) dbb.valBroadcastData = dbb.valBroadcastData[1:] log.Debug("delayedBlockBroadcaster.SetValidatorData: canceling old alarms (header and delay data) due to too much delay data", - "headerHash", dbb.valBroadcastData[0].headerHash, + "headerHash", dbb.valBroadcastData[0].HeaderHash, "alarmID-header", alarmHeaderID, "alarmID-delay", alarmDelayID, "nbDelayData", len(dbb.valBroadcastData), @@ -251,8 +238,9 @@ func (dbb *delayedBlockBroadcaster) SetBroadcastHandlers( mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, txBroadcast func(txData 
map[string][][]byte, pkBytes []byte) error, headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error, ) error { - if mbBroadcast == nil || txBroadcast == nil || headerBroadcast == nil { + if mbBroadcast == nil || txBroadcast == nil || headerBroadcast == nil || consensusMessageBroadcast == nil { return spos.ErrNilParameter } @@ -262,6 +250,7 @@ func (dbb *delayedBlockBroadcaster) SetBroadcastHandlers( dbb.broadcastMiniblocksData = mbBroadcast dbb.broadcastTxsData = txBroadcast dbb.broadcastHeader = headerBroadcast + dbb.broadcastConsensusMessage = consensusMessageBroadcast return nil } @@ -319,12 +308,12 @@ func (dbb *delayedBlockBroadcaster) broadcastDataForHeaders(headerHashes [][]byt time.Sleep(common.ExtraDelayForBroadcastBlockInfo) dbb.mutDataForBroadcast.Lock() - dataToBroadcast := make([]*delayedBroadcastData, 0) + dataToBroadcast := make([]*shared.DelayedBroadcastData, 0) OuterLoop: for i := len(dbb.delayedBroadcastData) - 1; i >= 0; i-- { for _, headerHash := range headerHashes { - if bytes.Equal(dbb.delayedBroadcastData[i].headerHash, headerHash) { + if bytes.Equal(dbb.delayedBroadcastData[i].HeaderHash, headerHash) { log.Debug("delayedBlockBroadcaster.broadcastDataForHeaders: leader broadcasts block data", "headerHash", headerHash, ) @@ -366,29 +355,29 @@ func (dbb *delayedBlockBroadcaster) scheduleValidatorBroadcast(dataForValidators log.Trace("delayedBlockBroadcaster.scheduleValidatorBroadcast: registered data for broadcast") for i := range dbb.valBroadcastData { log.Trace("delayedBlockBroadcaster.scheduleValidatorBroadcast", - "round", dbb.valBroadcastData[i].header.GetRound(), - "prevRandSeed", dbb.valBroadcastData[i].header.GetPrevRandSeed(), + "round", dbb.valBroadcastData[i].Header.GetRound(), + "prevRandSeed", dbb.valBroadcastData[i].Header.GetPrevRandSeed(), ) } for _, headerData := range dataForValidators { for _, broadcastData := range dbb.valBroadcastData { - sameRound := headerData.round == broadcastData.header.GetRound() - samePrevRandomness := bytes.Equal(headerData.prevRandSeed, broadcastData.header.GetPrevRandSeed()) + sameRound := headerData.round == broadcastData.Header.GetRound() + samePrevRandomness := bytes.Equal(headerData.prevRandSeed, broadcastData.Header.GetPrevRandSeed()) if sameRound && samePrevRandomness { - duration := validatorDelayPerOrder*time.Duration(broadcastData.order) + common.ExtraDelayForBroadcastBlockInfo - alarmID := prefixDelayDataAlarm + hex.EncodeToString(broadcastData.headerHash) + duration := validatorDelayPerOrder*time.Duration(broadcastData.Order) + common.ExtraDelayForBroadcastBlockInfo + alarmID := prefixDelayDataAlarm + hex.EncodeToString(broadcastData.HeaderHash) alarmsToAdd = append(alarmsToAdd, alarmParams{ id: alarmID, duration: duration, }) log.Trace("delayedBlockBroadcaster.scheduleValidatorBroadcast: scheduling delay data broadcast for notarized header", - "headerHash", broadcastData.headerHash, + "headerHash", broadcastData.HeaderHash, "alarmID", alarmID, "round", headerData.round, "prevRandSeed", headerData.prevRandSeed, - "consensusOrder", broadcastData.order, + "consensusOrder", broadcastData.Order, ) } } @@ -411,9 +400,9 @@ func (dbb *delayedBlockBroadcaster) alarmExpired(alarmID string) { } dbb.mutDataForBroadcast.Lock() - dataToBroadcast := make([]*delayedBroadcastData, 0) + dataToBroadcast := make([]*shared.DelayedBroadcastData, 0) for i, broadcastData := range dbb.valBroadcastData { - if bytes.Equal(broadcastData.headerHash, 
headerHash) { + if bytes.Equal(broadcastData.HeaderHash, headerHash) { log.Debug("delayedBlockBroadcaster.alarmExpired: validator broadcasts block data (with delay) instead of leader", "headerHash", headerHash, "alarmID", alarmID, @@ -440,9 +429,9 @@ func (dbb *delayedBlockBroadcaster) headerAlarmExpired(alarmID string) { } dbb.mutDataForBroadcast.Lock() - var vHeader *validatorHeaderBroadcastData + var vHeader *shared.ValidatorHeaderBroadcastData for i, broadcastData := range dbb.valHeaderBroadcastData { - if bytes.Equal(broadcastData.headerHash, headerHash) { + if bytes.Equal(broadcastData.HeaderHash, headerHash) { vHeader = broadcastData dbb.valHeaderBroadcastData = append(dbb.valHeaderBroadcastData[:i], dbb.valHeaderBroadcastData[i+1:]...) break @@ -463,7 +452,7 @@ func (dbb *delayedBlockBroadcaster) headerAlarmExpired(alarmID string) { "alarmID", alarmID, ) // broadcast header - err = dbb.broadcastHeader(vHeader.header, vHeader.pkBytes) + err = dbb.broadcastHeader(vHeader.Header, vHeader.PkBytes) if err != nil { log.Warn("delayedBlockBroadcaster.headerAlarmExpired", "error", err.Error(), "headerHash", headerHash, @@ -477,15 +466,15 @@ func (dbb *delayedBlockBroadcaster) headerAlarmExpired(alarmID string) { "headerHash", headerHash, "alarmID", alarmID, ) - go dbb.broadcastBlockData(vHeader.metaMiniBlocksData, vHeader.metaTransactionsData, vHeader.pkBytes, common.ExtraDelayForBroadcastBlockInfo) + go dbb.broadcastBlockData(vHeader.MetaMiniBlocksData, vHeader.MetaTransactionsData, vHeader.PkBytes, common.ExtraDelayForBroadcastBlockInfo) } } -func (dbb *delayedBlockBroadcaster) broadcastDelayedData(broadcastData []*delayedBroadcastData) { +func (dbb *delayedBlockBroadcaster) broadcastDelayedData(broadcastData []*shared.DelayedBroadcastData) { for _, bData := range broadcastData { go func(miniBlocks map[uint32][]byte, transactions map[string][][]byte, pkBytes []byte) { dbb.broadcastBlockData(miniBlocks, transactions, pkBytes, 0) - }(bData.miniBlocksData, bData.transactions, bData.pkBytes) + }(bData.MiniBlocksData, bData.Transactions, bData.PkBytes) } } @@ -596,7 +585,7 @@ func (dbb *delayedBlockBroadcaster) registerInterceptorsCallbackForShard( rootTopic string, cb func(topic string, hash []byte, data interface{}), ) error { - shardIDs := dbb.shardIdentifiers() + shardIDs := common.GetShardIDs(dbb.shardCoordinator.NumberOfShards()) for idx := range shardIDs { // interested only in cross shard data if idx == dbb.shardCoordinator.SelfId() { @@ -614,16 +603,6 @@ func (dbb *delayedBlockBroadcaster) registerInterceptorsCallbackForShard( return nil } -func (dbb *delayedBlockBroadcaster) shardIdentifiers() map[uint32]struct{} { - shardIdentifiers := make(map[uint32]struct{}) - for i := uint32(0); i < dbb.shardCoordinator.NumberOfShards(); i++ { - shardIdentifiers[i] = struct{}{} - } - shardIdentifiers[core.MetachainShardId] = struct{}{} - - return shardIdentifiers -} - func (dbb *delayedBlockBroadcaster) interceptedHeader(_ string, headerHash []byte, header interface{}) { headerHandler, ok := header.(data.HeaderHandler) if !ok { @@ -646,8 +625,8 @@ func (dbb *delayedBlockBroadcaster) interceptedHeader(_ string, headerHash []byt alarmsToCancel := make([]string, 0) dbb.mutDataForBroadcast.Lock() for i, broadcastData := range dbb.valHeaderBroadcastData { - samePrevRandSeed := bytes.Equal(broadcastData.header.GetPrevRandSeed(), headerHandler.GetPrevRandSeed()) - sameRound := broadcastData.header.GetRound() == headerHandler.GetRound() + samePrevRandSeed := 
bytes.Equal(broadcastData.Header.GetPrevRandSeed(), headerHandler.GetPrevRandSeed()) + sameRound := broadcastData.Header.GetRound() == headerHandler.GetRound() sameHeader := samePrevRandSeed && sameRound if sameHeader { @@ -676,24 +655,24 @@ func (dbb *delayedBlockBroadcaster) interceptedMiniBlockData(topic string, hash "topic", topic, ) - remainingValBroadcastData := make([]*delayedBroadcastData, 0) + remainingValBroadcastData := make([]*shared.DelayedBroadcastData, 0) alarmsToCancel := make([]string, 0) dbb.mutDataForBroadcast.Lock() for i, broadcastData := range dbb.valBroadcastData { - mbHashesMap := broadcastData.miniBlockHashes + mbHashesMap := broadcastData.MiniBlockHashes if len(mbHashesMap) > 0 && len(mbHashesMap[topic]) > 0 { - delete(broadcastData.miniBlockHashes[topic], string(hash)) + delete(broadcastData.MiniBlockHashes[topic], string(hash)) if len(mbHashesMap[topic]) == 0 { delete(mbHashesMap, topic) } } if len(mbHashesMap) == 0 { - alarmID := prefixDelayDataAlarm + hex.EncodeToString(broadcastData.headerHash) + alarmID := prefixDelayDataAlarm + hex.EncodeToString(broadcastData.HeaderHash) alarmsToCancel = append(alarmsToCancel, alarmID) log.Trace("delayedBlockBroadcaster.interceptedMiniBlockData: leader has broadcast block data, validator cancelling alarm", - "headerHash", broadcastData.headerHash, + "headerHash", broadcastData.HeaderHash, "alarmID", alarmID, ) } else { @@ -744,3 +723,8 @@ func (dbb *delayedBlockBroadcaster) extractMbsFromMeTo(header data.HeaderHandler return mbHashesForShard } + +// IsInterfaceNil returns true if there is no value under the interface +func (dbb *delayedBlockBroadcaster) IsInterfaceNil() bool { + return dbb == nil +} diff --git a/consensus/broadcast/delayedBroadcast_test.go b/consensus/broadcast/delayedBroadcast_test.go index 0f22e8a5157..961bc0efcc9 100644 --- a/consensus/broadcast/delayedBroadcast_test.go +++ b/consensus/broadcast/delayedBroadcast_test.go @@ -1,8 +1,10 @@ package broadcast_test import ( + "bytes" "encoding/hex" "errors" + "fmt" "strconv" "sync" "testing" @@ -13,14 +15,18 @@ import ( "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/broadcast" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/testscommon/pool" ) type validatorDelayArgs struct { @@ -34,23 +40,51 @@ type validatorDelayArgs struct { order uint32 } +type syncLogObserver struct { + sync.Mutex + buffer *bytes.Buffer +} + +// Write method that locks the mutex before writing +func (o *syncLogObserver) Write(p []byte) (n int, err error) { + o.Lock() + defer o.Unlock() + return o.buffer.Write(p) +} + +func (o *syncLogObserver) getBufferStr() string { + o.Lock() + logOutputStr := 
o.buffer.String() + o.Unlock() + + return logOutputStr +} + +func createLogsObserver() *syncLogObserver { + return &syncLogObserver{ + buffer: &bytes.Buffer{}, + } +} + func createValidatorDelayArgs(index int) *validatorDelayArgs { iStr := strconv.Itoa(index) return &validatorDelayArgs{ headerHash: []byte("header hash" + iStr), - header: &block.Header{ - PrevRandSeed: []byte("prev rand seed" + iStr), - Round: uint64(0), - MiniBlockHeaders: []block.MiniBlockHeader{ - { - Hash: []byte("miniBlockHash0" + iStr), - SenderShardID: 0, - ReceiverShardID: 0, - }, - { - Hash: []byte("miniBlockHash1" + iStr), - SenderShardID: 0, - ReceiverShardID: 1, + header: &block.HeaderV2{ + Header: &block.Header{ + PrevRandSeed: []byte("prev rand seed" + iStr), + Round: uint64(0), + MiniBlockHeaders: []block.MiniBlockHeader{ + { + Hash: []byte("miniBlockHash0" + iStr), + SenderShardID: 0, + ReceiverShardID: 0, + }, + { + Hash: []byte("miniBlockHash1" + iStr), + SenderShardID: 0, + ReceiverShardID: 1, + }, }, }, }, @@ -97,7 +131,7 @@ func createMetaBlock() *block.MetaBlock { } func createDefaultDelayedBroadcasterArgs() *broadcast.ArgsDelayedBlockBroadcaster { - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptorsContainer := createInterceptorContainer() dbbArgs := &broadcast.ArgsDelayedBlockBroadcaster{ ShardCoordinator: &mock.ShardCoordinatorMock{}, @@ -177,12 +211,15 @@ func TestDelayedBlockBroadcaster_HeaderReceivedNoDelayedDataRegistered(t *testin broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) metaBlock := createMetaBlock() @@ -210,12 +247,15 @@ func TestDelayedBlockBroadcaster_HeaderReceivedForRegisteredDelayedDataShouldBro broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) headerHash, _, miniblocksData, transactionsData := createDelayData("1") @@ -256,12 +296,15 @@ func TestDelayedBlockBroadcaster_HeaderReceivedForNotRegisteredDelayedDataShould broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) headerHash, _, miniblocksData, transactionsData := 
createDelayData("1") @@ -284,6 +327,74 @@ func TestDelayedBlockBroadcaster_HeaderReceivedForNotRegisteredDelayedDataShould assert.False(t, txBroadcastCalled.IsSet()) } +func TestDelayedBlockBroadcaster_HeaderReceivedWithoutSignaturesForShardShouldNotBroadcastTheData(t *testing.T) { + observer := createLogsObserver() + err := logger.AddLogObserver(observer, &logger.PlainFormatter{}) + require.Nil(t, err) + + originalLogPattern := logger.GetLogLevelPattern() + err = logger.SetLogLevel("*:TRACE") + require.Nil(t, err) + + defer func() { + err = logger.RemoveLogObserver(observer) + require.Nil(t, err) + err = logger.SetLogLevel(originalLogPattern) + require.Nil(t, err) + }() + + mbBroadcastCalled := atomic.Flag{} + txBroadcastCalled := atomic.Flag{} + + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { + mbBroadcastCalled.SetValue(true) + return nil + } + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { + txBroadcastCalled.SetValue(true) + return nil + } + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { + return nil + } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.Nil(t, err) + + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) + require.Nil(t, err) + + headerHash, _, miniblocksData, transactionsData := createDelayData("1") + delayedData := broadcast.CreateDelayBroadcastDataForLeader(headerHash, miniblocksData, transactionsData) + err = dbb.SetLeaderData(delayedData) + + metaBlock := createMetaBlock() + metaBlock.ShardInfo = []block.ShardData{} + + assert.Nil(t, err) + time.Sleep(10 * time.Millisecond) + assert.False(t, mbBroadcastCalled.IsSet()) + assert.False(t, txBroadcastCalled.IsSet()) + + dbb.HeaderReceived(metaBlock, []byte("meta hash")) + sleepTime := common.ExtraDelayForBroadcastBlockInfo + + common.ExtraDelayBetweenBroadcastMbsAndTxs + + 100*time.Millisecond + time.Sleep(sleepTime) + + logOutputStr := observer.getBufferStr() + expectedLogMsg := "delayedBlockBroadcaster.headerReceived: header received with no shardData for current shard" + require.Contains(t, logOutputStr, expectedLogMsg) + require.Contains(t, logOutputStr, fmt.Sprintf("headerHash = %s", hex.EncodeToString(headerHash))) + + assert.False(t, mbBroadcastCalled.IsSet()) + assert.False(t, txBroadcastCalled.IsSet()) +} + func TestDelayedBlockBroadcaster_HeaderReceivedForNextRegisteredDelayedDataShouldBroadcastBoth(t *testing.T) { t.Parallel() @@ -301,12 +412,15 @@ func TestDelayedBlockBroadcaster_HeaderReceivedForNextRegisteredDelayedDataShoul broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) headerHash, _, miniblocksData, transactionsData := createDelayData("1") @@ -368,6 +482,72 @@ func TestDelayedBlockBroadcaster_SetLeaderData(t *testing.T) { require.Equal(t, 1, len(vbb)) } +func 
TestDelayedBlockBroadcaster_SetLeaderDataOverCacheSizeShouldBroadcastOldest(t *testing.T) { + observer := createLogsObserver() + err := logger.AddLogObserver(observer, &logger.PlainFormatter{}) + require.Nil(t, err) + + originalLogPattern := logger.GetLogLevelPattern() + err = logger.SetLogLevel("*:DEBUG") + require.Nil(t, err) + + defer func() { + err = logger.RemoveLogObserver(observer) + require.Nil(t, err) + err = logger.SetLogLevel(originalLogPattern) + require.Nil(t, err) + }() + + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { + return nil + } + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { + return nil + } + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { + return nil + } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.Nil(t, err) + + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) + require.Nil(t, err) + + headerHash1, _, miniBlockData1, transactionsData1 := createDelayData("1") + delayedData1 := broadcast.CreateDelayBroadcastDataForLeader(headerHash1, miniBlockData1, transactionsData1) + err = dbb.SetLeaderData(delayedData1) + require.Nil(t, err) + time.Sleep(10 * time.Millisecond) + + headerHash2, _, miniBlockData2, transactionsData2 := createDelayData("2") + delayedData2 := broadcast.CreateDelayBroadcastDataForLeader(headerHash2, miniBlockData2, transactionsData2) + err = dbb.SetLeaderData(delayedData2) + require.Nil(t, err) + time.Sleep(10 * time.Millisecond) + + // should trigger the log message + headerHash3, _, miniBlockData3, transactionsData3 := createDelayData("3") + delayedData3 := broadcast.CreateDelayBroadcastDataForLeader(headerHash3, miniBlockData3, transactionsData3) + err = dbb.SetLeaderData(delayedData3) + require.Nil(t, err) + time.Sleep(10 * time.Millisecond) + + logOutputStr := observer.getBufferStr() + expectedLogMsg := "delayedBlockBroadcaster.SetLeaderData: leader broadcasts old data before alarm due to too much delay data" + require.Contains(t, logOutputStr, expectedLogMsg) + require.Contains(t, logOutputStr, fmt.Sprintf("headerHash = %s", hex.EncodeToString(headerHash1))) + require.Contains(t, logOutputStr, "nbDelayedData = 3") + require.Contains(t, logOutputStr, "maxDelayCacheSize = 2") + + vbb := dbb.GetLeaderBroadcastData() + require.Equal(t, 2, len(vbb)) +} + func TestDelayedBlockBroadcaster_SetValidatorDataNilDataShouldErr(t *testing.T) { t.Parallel() @@ -405,6 +585,88 @@ func TestDelayedBlockBroadcaster_SetValidatorData(t *testing.T) { require.Equal(t, 1, len(vbb)) } +func TestDelayedBlockBroadcaster_SetBroadcastHandlersFailsIfNilHandler(t *testing.T) { + t.Parallel() + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.Nil(t, err) + + err = dbb.SetBroadcastHandlers(nil, nil, nil, nil) + require.Equal(t, spos.ErrNilParameter, err) +} + +func TestDelayedBlockBroadcaster_SetHeaderForValidatorWithoutSignaturesShouldNotSetAlarm(t *testing.T) { + observer := createLogsObserver() + err := logger.AddLogObserver(observer, &logger.PlainFormatter{}) + require.Nil(t, err) + + originalLogPattern := logger.GetLogLevelPattern() + err = logger.SetLogLevel("*:TRACE") + require.Nil(t, err) + + defer func() { + err = 
logger.RemoveLogObserver(observer) + require.Nil(t, err) + err = logger.SetLogLevel(originalLogPattern) + require.Nil(t, err) + }() + + mbBroadcastCalled := atomic.Counter{} + txBroadcastCalled := atomic.Counter{} + headerBroadcastCalled := atomic.Counter{} + + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { + mbBroadcastCalled.Increment() + return nil + } + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { + txBroadcastCalled.Increment() + return nil + } + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { + headerBroadcastCalled.Increment() + return nil + } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.Nil(t, err) + + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) + require.Nil(t, err) + + vArgs := createValidatorDelayArgs(0) + + valHeaderData := broadcast.CreateValidatorHeaderBroadcastData( + vArgs.headerHash, + vArgs.header, + vArgs.metaMiniBlocks, + vArgs.metaTransactions, + vArgs.order, + ) + err = dbb.SetHeaderForValidator(valHeaderData) + require.Nil(t, err) + + logOutputStr := observer.getBufferStr() + expectedLogMsg := "delayedBlockBroadcaster.SetHeaderForValidator: header alarm has not been set" + require.Contains(t, logOutputStr, expectedLogMsg) + require.Contains(t, logOutputStr, fmt.Sprintf("validatorConsensusOrder = %d", vArgs.order)) + + vbb := dbb.GetValidatorHeaderBroadcastData() + require.Equal(t, 0, len(vbb)) + + sleepTime := broadcast.ValidatorDelayPerOrder()*time.Duration(vArgs.order) + + time.Millisecond*100 + time.Sleep(sleepTime) + + vbb = dbb.GetValidatorHeaderBroadcastData() + require.Equal(t, 0, len(vbb)) +} + func TestDelayedBlockBroadcaster_SetHeaderForValidatorShouldSetAlarmAndBroadcastHeader(t *testing.T) { t.Parallel() @@ -424,12 +686,15 @@ func TestDelayedBlockBroadcaster_SetHeaderForValidatorShouldSetAlarmAndBroadcast headerBroadcastCalled.Increment() return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -465,6 +730,77 @@ func TestDelayedBlockBroadcaster_SetHeaderForValidatorShouldSetAlarmAndBroadcast require.Equal(t, 0, len(vbb)) } +func TestDelayedBlockBroadcaster_SetHeaderForValidator_BroadcastHeaderError(t *testing.T) { + observer := createLogsObserver() + err := logger.AddLogObserver(observer, &logger.PlainFormatter{}) + require.Nil(t, err) + + defer func() { + err = logger.RemoveLogObserver(observer) + require.Nil(t, err) + }() + + mbBroadcastCalled := atomic.Counter{} + txBroadcastCalled := atomic.Counter{} + + broadcastError := "broadcast error" + + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { + mbBroadcastCalled.Increment() + return nil + } + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { + txBroadcastCalled.Increment() + return nil + } + broadcastHeader := func(header data.HeaderHandler, pk 
[]byte) error { + return errors.New(broadcastError) + } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.Nil(t, err) + + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) + require.Nil(t, err) + + vArgs := createValidatorDelayArgs(0) + err = vArgs.header.SetSignature([]byte("agg sig")) + require.Nil(t, err) + + valHeaderData := broadcast.CreateValidatorHeaderBroadcastData( + vArgs.headerHash, + vArgs.header, + vArgs.metaMiniBlocks, + vArgs.metaTransactions, + vArgs.order, + ) + err = dbb.SetHeaderForValidator(valHeaderData) + require.Nil(t, err) + + vbb := dbb.GetValidatorHeaderBroadcastData() + require.Equal(t, 1, len(vbb)) + require.Equal(t, int64(0), mbBroadcastCalled.Get()) + require.Equal(t, int64(0), txBroadcastCalled.Get()) + + sleepTime := broadcast.ValidatorDelayPerOrder()*time.Duration(vArgs.order) + + time.Millisecond*100 + time.Sleep(sleepTime) + + logOutputStr := observer.getBufferStr() + expectedLogMsg := "delayedBlockBroadcaster.headerAlarmExpired error = %s" + require.Contains(t, logOutputStr, fmt.Sprintf(expectedLogMsg, broadcastError)) + + require.Equal(t, int64(0), mbBroadcastCalled.Get()) + require.Equal(t, int64(0), txBroadcastCalled.Get()) + + vbb = dbb.GetValidatorHeaderBroadcastData() + require.Equal(t, 0, len(vbb)) +} + func TestDelayedBlockBroadcaster_SetValidatorDataFinalizedMetaHeaderShouldSetAlarmAndBroadcastHeaderAndData(t *testing.T) { t.Parallel() @@ -484,6 +820,9 @@ func TestDelayedBlockBroadcaster_SetValidatorDataFinalizedMetaHeaderShouldSetAla headerBroadcastCalled.Increment() return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() delayBroadcasterArgs.ShardCoordinator = mock.ShardCoordinatorMock{ @@ -492,7 +831,7 @@ func TestDelayedBlockBroadcaster_SetValidatorDataFinalizedMetaHeaderShouldSetAla dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -552,6 +891,9 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderShouldCancelAlarm(t *testing.T headerBroadcastCalled.Increment() return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() delayBroadcasterArgs.ShardCoordinator = mock.ShardCoordinatorMock{ @@ -560,7 +902,7 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderShouldCancelAlarm(t *testing.T dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -621,6 +963,9 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderShouldCancelAlarmForHeaderBroa headerBroadcastCalled.Increment() return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + 
} delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() delayBroadcasterArgs.ShardCoordinator = mock.ShardCoordinatorMock{ @@ -629,7 +974,7 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderShouldCancelAlarmForHeaderBroa dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -689,6 +1034,9 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderInvalidOrDifferentShouldIgnore headerBroadcastCalled.Increment() return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() delayBroadcasterArgs.ShardCoordinator = mock.ShardCoordinatorMock{ @@ -697,7 +1045,7 @@ func TestDelayedBlockBroadcaster_InterceptedHeaderInvalidOrDifferentShouldIgnore dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -802,12 +1150,15 @@ func TestDelayedBlockBroadcaster_ScheduleValidatorBroadcastDifferentHeaderRoundS broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -859,12 +1210,15 @@ func TestDelayedBlockBroadcaster_ScheduleValidatorBroadcastDifferentPrevRandShou broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -919,12 +1273,15 @@ func TestDelayedBlockBroadcaster_ScheduleValidatorBroadcastSameRoundAndPrevRandS broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -979,12 +1336,15 @@ 
func TestDelayedBlockBroadcaster_AlarmExpiredShouldBroadcastTheDataForRegistered broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1032,12 +1392,15 @@ func TestDelayedBlockBroadcaster_AlarmExpiredShouldDoNothingForNotRegisteredData broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1069,6 +1432,60 @@ func TestDelayedBlockBroadcaster_AlarmExpiredShouldDoNothingForNotRegisteredData require.Equal(t, 1, len(vbd)) } +func TestDelayedBlockBroadcaster_HeaderAlarmExpired_InvalidAlarmID(t *testing.T) { + observer := createLogsObserver() + err := logger.AddLogObserver(observer, &logger.PlainFormatter{}) + require.Nil(t, err) + + defer func() { + err = logger.RemoveLogObserver(observer) + require.Nil(t, err) + }() + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.Nil(t, err) + + invalidAlarmID := "invalid_alarm_id" + dbb.HeaderAlarmExpired(invalidAlarmID) + + logOutputStr := observer.getBufferStr() + expectedLogMsg := "delayedBlockBroadcaster.headerAlarmExpired" + require.Contains(t, logOutputStr, expectedLogMsg) + require.Contains(t, logOutputStr, fmt.Sprintf("alarmID = %s", invalidAlarmID)) +} + +func TestDelayedBlockBroadcaster_HeaderAlarmExpired_HeaderDataNil(t *testing.T) { + observer := createLogsObserver() + err := logger.AddLogObserver(observer, &logger.PlainFormatter{}) + require.Nil(t, err) + + originalLogPattern := logger.GetLogLevelPattern() + err = logger.SetLogLevel("*:DEBUG") + require.Nil(t, err) + + defer func() { + err = logger.RemoveLogObserver(observer) + require.Nil(t, err) + err = logger.SetLogLevel(originalLogPattern) + require.Nil(t, err) + }() + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.Nil(t, err) + + invalidHeaderHash := []byte("invalid_header_hash") + alarmID := "header_" + hex.EncodeToString(invalidHeaderHash) + + dbb.HeaderAlarmExpired(alarmID) + + logOutputStr := observer.getBufferStr() + expectedLogMsg := "delayedBlockBroadcaster.headerAlarmExpired: alarm data is nil" + require.Contains(t, logOutputStr, expectedLogMsg) + require.Contains(t, logOutputStr, "alarmID = "+alarmID) +} + func TestDelayedBlockBroadcaster_RegisterInterceptorCallback(t *testing.T) { delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() var cbsHeader []func(topic string, hash []byte, data interface{}) @@ 
-1139,6 +1556,53 @@ func TestDelayedBlockBroadcaster_RegisterInterceptorCallback(t *testing.T) { require.Equal(t, 2, nbRegisteredMbsHandlers) } +func TestDelayedBlockBroadcaster_BroadcastBlockDataFailedBroadcast(t *testing.T) { + observer := createLogsObserver() + err := logger.AddLogObserver(observer, &logger.PlainFormatter{}) + require.Nil(t, err) + + defer func() { + err = logger.RemoveLogObserver(observer) + require.Nil(t, err) + }() + + errMiniBlocks := "mini blocks broadcast error" + broadcastMiniBlocks := func(mbData map[uint32][]byte, pk []byte) error { + return errors.New(errMiniBlocks) + } + errTxs := "transactions broadcast error" + broadcastTransactions := func(txData map[string][][]byte, pk []byte) error { + return errors.New(errTxs) + } + broadcastHeader := func(header data.HeaderHandler, pk []byte) error { + return nil + } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } + + delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() + dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) + require.Nil(t, err) + + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) + require.Nil(t, err) + + dbb.BroadcastBlockData(nil, nil, nil, time.Millisecond*100) + + logOutputStr := observer.getBufferStr() + require.Contains(t, logOutputStr, errMiniBlocks) + require.Contains(t, logOutputStr, errTxs) +} + +func TestDelayedBlockBroadcaster_GetShardDataFromMetaChainBlockInvalidMetaHandler(t *testing.T) { + shardID := uint32(0) + + _, _, err := broadcast.GetShardDataFromMetaChainBlock(nil, shardID) + require.NotNil(t, err) + require.Equal(t, spos.ErrInvalidMetaHeader, err) +} + func TestDelayedBlockBroadcaster_GetShardDataFromMetaChainBlock(t *testing.T) { metaHeader := createMetaBlock() shardID := uint32(0) @@ -1180,12 +1644,15 @@ func TestDelayedBlockBroadcaster_InterceptedMiniBlockForNotSetValDataShouldBroad broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1243,12 +1710,15 @@ func TestDelayedBlockBroadcaster_InterceptedMiniBlockOutOfManyForSetValDataShoul broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1307,12 +1777,15 @@ func TestDelayedBlockBroadcaster_InterceptedMiniBlockFinalForSetValDataShouldNot broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } 
delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) @@ -1371,12 +1844,15 @@ func TestDelayedBlockBroadcaster_Close(t *testing.T) { broadcastHeader := func(header data.HeaderHandler, pk []byte) error { return nil } + broadcastConsensusMessage := func(message *consensus.Message) error { + return nil + } delayBroadcasterArgs := createDefaultDelayedBroadcasterArgs() dbb, err := broadcast.NewDelayedBlockBroadcaster(delayBroadcasterArgs) require.Nil(t, err) - err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader) + err = dbb.SetBroadcastHandlers(broadcastMiniBlocks, broadcastTransactions, broadcastHeader, broadcastConsensusMessage) require.Nil(t, err) vArgs := createValidatorDelayArgs(0) diff --git a/consensus/broadcast/errors.go b/consensus/broadcast/errors.go index 86acef6937b..c16c878bc50 100644 --- a/consensus/broadcast/errors.go +++ b/consensus/broadcast/errors.go @@ -4,3 +4,6 @@ import "errors" // ErrNilKeysHandler signals that a nil keys handler was provided var ErrNilKeysHandler = errors.New("nil keys handler") + +// ErrNilDelayedBroadcaster signals that a nil delayed broadcaster was provided +var ErrNilDelayedBroadcaster = errors.New("nil delayed broadcaster") diff --git a/consensus/broadcast/export.go b/consensus/broadcast/export.go index e7b0e4dfa80..5351003e38f 100644 --- a/consensus/broadcast/export.go +++ b/consensus/broadcast/export.go @@ -6,7 +6,9 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/sharding" ) @@ -32,14 +34,14 @@ func CreateDelayBroadcastDataForValidator( miniBlockHashes map[string]map[string]struct{}, transactionsData map[string][][]byte, order uint32, -) *delayedBroadcastData { - return &delayedBroadcastData{ - headerHash: headerHash, - header: header, - miniBlocksData: miniblocksData, - miniBlockHashes: miniBlockHashes, - transactions: transactionsData, - order: order, +) *shared.DelayedBroadcastData { + return &shared.DelayedBroadcastData{ + HeaderHash: headerHash, + Header: header, + MiniBlocksData: miniblocksData, + MiniBlockHashes: miniBlockHashes, + Transactions: transactionsData, + Order: order, } } @@ -50,13 +52,13 @@ func CreateValidatorHeaderBroadcastData( metaMiniBlocksData map[uint32][]byte, metaTransactionsData map[string][][]byte, order uint32, -) *validatorHeaderBroadcastData { - return &validatorHeaderBroadcastData{ - headerHash: headerHash, - header: header, - metaMiniBlocksData: metaMiniBlocksData, - metaTransactionsData: metaTransactionsData, - order: order, +) *shared.ValidatorHeaderBroadcastData { + return &shared.ValidatorHeaderBroadcastData{ + HeaderHash: headerHash, + Header: header, + MetaMiniBlocksData: metaMiniBlocksData, + MetaTransactionsData: metaTransactionsData, + Order: order, } } @@ -65,11 +67,11 @@ func CreateDelayBroadcastDataForLeader( headerHash []byte, miniblocks 
map[uint32][]byte, transactions map[string][][]byte, -) *delayedBroadcastData { - return &delayedBroadcastData{ - headerHash: headerHash, - miniBlocksData: miniblocks, - transactions: transactions, +) *shared.DelayedBroadcastData { + return &shared.DelayedBroadcastData{ + HeaderHash: headerHash, + MiniBlocksData: miniblocks, + Transactions: transactions, } } @@ -80,9 +82,9 @@ func (dbb *delayedBlockBroadcaster) HeaderReceived(headerHandler data.HeaderHand } // GetValidatorBroadcastData returns the set validator delayed broadcast data -func (dbb *delayedBlockBroadcaster) GetValidatorBroadcastData() []*delayedBroadcastData { +func (dbb *delayedBlockBroadcaster) GetValidatorBroadcastData() []*shared.DelayedBroadcastData { dbb.mutDataForBroadcast.RLock() - copyValBroadcastData := make([]*delayedBroadcastData, len(dbb.valBroadcastData)) + copyValBroadcastData := make([]*shared.DelayedBroadcastData, len(dbb.valBroadcastData)) copy(copyValBroadcastData, dbb.valBroadcastData) dbb.mutDataForBroadcast.RUnlock() @@ -90,9 +92,9 @@ func (dbb *delayedBlockBroadcaster) GetValidatorBroadcastData() []*delayedBroadc } // GetValidatorHeaderBroadcastData - -func (dbb *delayedBlockBroadcaster) GetValidatorHeaderBroadcastData() []*validatorHeaderBroadcastData { +func (dbb *delayedBlockBroadcaster) GetValidatorHeaderBroadcastData() []*shared.ValidatorHeaderBroadcastData { dbb.mutDataForBroadcast.RLock() - copyValHeaderBroadcastData := make([]*validatorHeaderBroadcastData, len(dbb.valHeaderBroadcastData)) + copyValHeaderBroadcastData := make([]*shared.ValidatorHeaderBroadcastData, len(dbb.valHeaderBroadcastData)) copy(copyValHeaderBroadcastData, dbb.valHeaderBroadcastData) dbb.mutDataForBroadcast.RUnlock() @@ -100,9 +102,9 @@ func (dbb *delayedBlockBroadcaster) GetValidatorHeaderBroadcastData() []*validat } // GetLeaderBroadcastData returns the set leader delayed broadcast data -func (dbb *delayedBlockBroadcaster) GetLeaderBroadcastData() []*delayedBroadcastData { +func (dbb *delayedBlockBroadcaster) GetLeaderBroadcastData() []*shared.DelayedBroadcastData { dbb.mutDataForBroadcast.RLock() - copyDelayBroadcastData := make([]*delayedBroadcastData, len(dbb.delayedBroadcastData)) + copyDelayBroadcastData := make([]*shared.DelayedBroadcastData, len(dbb.delayedBroadcastData)) copy(copyDelayBroadcastData, dbb.delayedBroadcastData) dbb.mutDataForBroadcast.RUnlock() @@ -132,6 +134,11 @@ func (dbb *delayedBlockBroadcaster) AlarmExpired(headerHash string) { dbb.alarmExpired(headerHash) } +// HeaderAlarmExpired - +func (dbb *delayedBlockBroadcaster) HeaderAlarmExpired(headerHash string) { + dbb.headerAlarmExpired(headerHash) +} + // GetShardDataFromMetaChainBlock - func GetShardDataFromMetaChainBlock( headerHandler data.HeaderHandler, @@ -166,6 +173,16 @@ func (dbb *delayedBlockBroadcaster) InterceptedHeaderData(topic string, hash []b dbb.interceptedHeader(topic, hash, header) } +// BroadcastBlockData - +func (dbb *delayedBlockBroadcaster) BroadcastBlockData( + miniBlocks map[uint32][]byte, + transactions map[string][][]byte, + pkBytes []byte, + delay time.Duration, +) { + dbb.broadcastBlockData(miniBlocks, transactions, pkBytes, delay) +} + // NewCommonMessenger will return a new instance of a commonMessenger func NewCommonMessenger( marshalizer marshal.Marshalizer, diff --git a/consensus/broadcast/export_test.go b/consensus/broadcast/export_test.go new file mode 100644 index 00000000000..646dfa9b161 --- /dev/null +++ b/consensus/broadcast/export_test.go @@ -0,0 +1,12 @@ +package broadcast + +import ( + 
"github.com/multiversx/mx-chain-core-go/marshal" +) + +// SetMarshalizerMeta sets the unexported marshaller +func (mcm *metaChainMessenger) SetMarshalizerMeta( + m marshal.Marshalizer, +) { + mcm.marshalizer = m +} diff --git a/consensus/broadcast/interface.go b/consensus/broadcast/interface.go new file mode 100644 index 00000000000..d453f7708d9 --- /dev/null +++ b/consensus/broadcast/interface.go @@ -0,0 +1,22 @@ +package broadcast + +import ( + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" +) + +// DelayedBroadcaster exposes functionality for handling the consensus members broadcasting of delay data +type DelayedBroadcaster interface { + SetLeaderData(data *shared.DelayedBroadcastData) error + SetValidatorData(data *shared.DelayedBroadcastData) error + SetHeaderForValidator(vData *shared.ValidatorHeaderBroadcastData) error + SetBroadcastHandlers( + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error, + ) error + Close() + IsInterfaceNil() bool +} diff --git a/consensus/broadcast/metaChainMessenger.go b/consensus/broadcast/metaChainMessenger.go index daca3b436a5..b495719a13d 100644 --- a/consensus/broadcast/metaChainMessenger.go +++ b/consensus/broadcast/metaChainMessenger.go @@ -5,8 +5,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process/factory" ) @@ -32,27 +34,13 @@ func NewMetaChainMessenger( return nil, err } - dbbArgs := &ArgsDelayedBlockBroadcaster{ - InterceptorsContainer: args.InterceptorsContainer, - HeadersSubscriber: args.HeadersSubscriber, - LeaderCacheSize: args.MaxDelayCacheSize, - ValidatorCacheSize: args.MaxValidatorDelayCacheSize, - ShardCoordinator: args.ShardCoordinator, - AlarmScheduler: args.AlarmScheduler, - } - - dbb, err := NewDelayedBlockBroadcaster(dbbArgs) - if err != nil { - return nil, err - } - cm := &commonMessenger{ marshalizer: args.Marshalizer, hasher: args.Hasher, messenger: args.Messenger, shardCoordinator: args.ShardCoordinator, peerSignatureHandler: args.PeerSignatureHandler, - delayedBlockBroadcaster: dbb, + delayedBlockBroadcaster: args.DelayedBroadcaster, keysHandler: args.KeysHandler, } @@ -60,7 +48,11 @@ func NewMetaChainMessenger( commonMessenger: cm, } - err = dbb.SetBroadcastHandlers(mcm.BroadcastMiniBlocks, mcm.BroadcastTransactions, mcm.BroadcastHeader) + err = mcm.delayedBlockBroadcaster.SetBroadcastHandlers( + mcm.BroadcastMiniBlocks, + mcm.BroadcastTransactions, + mcm.BroadcastHeader, + mcm.BroadcastConsensusMessage) if err != nil { return nil, err } @@ -124,6 +116,14 @@ func (mcm *metaChainMessenger) BroadcastHeader(header data.HeaderHandler, pkByte return nil } +// BroadcastEquivalentProof will broadcast the proof for a header on the metachain common 
topic +func (mcm *metaChainMessenger) BroadcastEquivalentProof(proof data.HeaderProofHandler, pkBytes []byte) error { + identifierMetaAll := mcm.shardCoordinator.CommunicationIdentifier(core.AllShardId) + topic := common.EquivalentProofsTopic + identifierMetaAll + + return mcm.broadcastEquivalentProof(proof, pkBytes, topic) +} + // BroadcastBlockDataLeader broadcasts the block data as consensus group leader func (mcm *metaChainMessenger) BroadcastBlockDataLeader( _ data.HeaderHandler, @@ -154,13 +154,13 @@ func (mcm *metaChainMessenger) PrepareBroadcastHeaderValidator( return } - vData := &validatorHeaderBroadcastData{ - headerHash: headerHash, - header: header, - metaMiniBlocksData: miniBlocks, - metaTransactionsData: transactions, - order: uint32(idx), - pkBytes: pkBytes, + vData := &shared.ValidatorHeaderBroadcastData{ + HeaderHash: headerHash, + Header: header, + MetaMiniBlocksData: miniBlocks, + MetaTransactionsData: transactions, + Order: uint32(idx), + PkBytes: pkBytes, } err = mcm.delayedBlockBroadcaster.SetHeaderForValidator(vData) @@ -180,6 +180,16 @@ func (mcm *metaChainMessenger) PrepareBroadcastBlockDataValidator( ) { } +// PrepareBroadcastBlockDataWithEquivalentProofs prepares the broadcast of block data with equivalent proofs +func (mcm *metaChainMessenger) PrepareBroadcastBlockDataWithEquivalentProofs( + _ data.HeaderHandler, + miniBlocks map[uint32][]byte, + transactions map[string][][]byte, + pkBytes []byte, +) { + go mcm.BroadcastBlockData(miniBlocks, transactions, pkBytes, common.ExtraDelayForBroadcastBlockInfo) +} + // Close closes all the started infinite looping goroutines and subcomponents func (mcm *metaChainMessenger) Close() { mcm.delayedBlockBroadcaster.Close() diff --git a/consensus/broadcast/metaChainMessenger_test.go b/consensus/broadcast/metaChainMessenger_test.go index 01cbb6a151d..7f188e48c59 100644 --- a/consensus/broadcast/metaChainMessenger_test.go +++ b/consensus/broadcast/metaChainMessenger_test.go @@ -7,16 +7,22 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/broadcast" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/testscommon" + consensusMock "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/testscommon/pool" ) var nodePkBytes = []byte("node public key bytes") @@ -27,10 +33,11 @@ func createDefaultMetaChainArgs() broadcast.MetaChainMessengerArgs { shardCoordinatorMock := &mock.ShardCoordinatorMock{} singleSignerMock := &mock.SingleSignerMock{} hasher := &hashingMocks.HasherMock{} - headersSubscriber := &mock.HeadersCacherStub{} + 
headersSubscriber := &pool.HeadersPoolStub{} interceptorsContainer := createInterceptorContainer() peerSigHandler := &mock.PeerSignatureHandler{Signer: singleSignerMock} - alarmScheduler := &mock.AlarmSchedulerStub{} + alarmScheduler := &testscommon.AlarmSchedulerStub{} + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{} return broadcast.MetaChainMessengerArgs{ CommonMessengerArgs: broadcast.CommonMessengerArgs{ @@ -45,6 +52,7 @@ func createDefaultMetaChainArgs() broadcast.MetaChainMessengerArgs { MaxDelayCacheSize: 2, AlarmScheduler: alarmScheduler, KeysHandler: &testscommon.KeysHandlerStub{}, + DelayedBroadcaster: delayedBroadcaster, }, } } @@ -94,6 +102,14 @@ func TestMetaChainMessenger_NilKeysHandlerShouldError(t *testing.T) { assert.Equal(t, broadcast.ErrNilKeysHandler, err) } +func TestMetaChainMessenger_NilDelayedBroadcasterShouldError(t *testing.T) { + args := createDefaultMetaChainArgs() + args.DelayedBroadcaster = nil + scm, err := broadcast.NewMetaChainMessenger(args) + + assert.Nil(t, scm) + assert.Equal(t, broadcast.ErrNilDelayedBroadcaster, err) +} func TestMetaChainMessenger_NewMetaChainMessengerShouldWork(t *testing.T) { args := createDefaultMetaChainArgs() mcm, err := broadcast.NewMetaChainMessenger(args) @@ -292,3 +308,114 @@ func TestMetaChainMessenger_BroadcastBlockDataLeader(t *testing.T) { assert.Equal(t, len(transactions), numBroadcast) }) } + +func TestMetaChainMessenger_Close(t *testing.T) { + t.Parallel() + + args := createDefaultMetaChainArgs() + closeCalled := false + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{ + CloseCalled: func() { + closeCalled = true + }, + } + args.DelayedBroadcaster = delayedBroadcaster + + mcm, _ := broadcast.NewMetaChainMessenger(args) + require.NotNil(t, mcm) + mcm.Close() + assert.True(t, closeCalled) +} + +func TestMetaChainMessenger_PrepareBroadcastHeaderValidator(t *testing.T) { + t.Parallel() + + t.Run("Nil header", func(t *testing.T) { + t.Parallel() + + args := createDefaultMetaChainArgs() + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + require.Fail(t, "SetHeaderForValidator should not be called") + return nil + }, + } + args.DelayedBroadcaster = delayedBroadcaster + + mcm, _ := broadcast.NewMetaChainMessenger(args) + require.NotNil(t, mcm) + mcm.PrepareBroadcastHeaderValidator(nil, make(map[uint32][]byte), make(map[string][][]byte), 0, make([]byte, 0)) + }) + t.Run("Err on core.CalculateHash", func(t *testing.T) { + t.Parallel() + + args := createDefaultMetaChainArgs() + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + require.Fail(t, "SetHeaderForValidator should not be called") + return nil + }, + } + args.DelayedBroadcaster = delayedBroadcaster + + header := &block.Header{} + mcm, _ := broadcast.NewMetaChainMessenger(args) + require.NotNil(t, mcm) + mcm.SetMarshalizerMeta(nil) + mcm.PrepareBroadcastHeaderValidator(header, make(map[uint32][]byte), make(map[string][][]byte), 0, make([]byte, 0)) + }) + t.Run("Err on SetHeaderForValidator", func(t *testing.T) { + t.Parallel() + + args := createDefaultMetaChainArgs() + checkVarModified := false + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + checkVarModified = true + return expectedErr + }, + } + args.DelayedBroadcaster = delayedBroadcaster + + 
mcm, _ := broadcast.NewMetaChainMessenger(args) + require.NotNil(t, mcm) + header := &block.Header{} + mcm.PrepareBroadcastHeaderValidator(header, make(map[uint32][]byte), make(map[string][][]byte), 0, make([]byte, 0)) + assert.True(t, checkVarModified) + }) +} + +func TestMetaChainMessenger_BroadcastBlock(t *testing.T) { + t.Parallel() + + t.Run("Err nil blockData", func(t *testing.T) { + args := createDefaultMetaChainArgs() + mcm, _ := broadcast.NewMetaChainMessenger(args) + require.NotNil(t, mcm) + err := mcm.BroadcastBlock(nil, nil) + assert.NotNil(t, err) + }) +} + +func TestMetaChainMessenger_NewMetaChainMessengerFailSetBroadcast(t *testing.T) { + t.Parallel() + + args := createDefaultMetaChainArgs() + varModified := false + delayedBroadcaster := &consensusMock.DelayedBroadcasterMock{ + SetBroadcastHandlersCalled: func( + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error) error { + varModified = true + return expectedErr + }, + } + args.DelayedBroadcaster = delayedBroadcaster + + mcm, err := broadcast.NewMetaChainMessenger(args) + assert.Nil(t, mcm) + assert.NotNil(t, err) + assert.True(t, varModified) +} diff --git a/consensus/broadcast/shardChainMessenger.go b/consensus/broadcast/shardChainMessenger.go index ac7485a8d1f..233055bc9f8 100644 --- a/consensus/broadcast/shardChainMessenger.go +++ b/consensus/broadcast/shardChainMessenger.go @@ -7,8 +7,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process/factory" ) @@ -37,35 +39,24 @@ func NewShardChainMessenger( } cm := &commonMessenger{ - marshalizer: args.Marshalizer, - hasher: args.Hasher, - messenger: args.Messenger, - shardCoordinator: args.ShardCoordinator, - peerSignatureHandler: args.PeerSignatureHandler, - keysHandler: args.KeysHandler, - } - - dbbArgs := &ArgsDelayedBlockBroadcaster{ - InterceptorsContainer: args.InterceptorsContainer, - HeadersSubscriber: args.HeadersSubscriber, - LeaderCacheSize: args.MaxDelayCacheSize, - ValidatorCacheSize: args.MaxValidatorDelayCacheSize, - ShardCoordinator: args.ShardCoordinator, - AlarmScheduler: args.AlarmScheduler, - } - - dbb, err := NewDelayedBlockBroadcaster(dbbArgs) - if err != nil { - return nil, err + marshalizer: args.Marshalizer, + hasher: args.Hasher, + messenger: args.Messenger, + shardCoordinator: args.ShardCoordinator, + peerSignatureHandler: args.PeerSignatureHandler, + keysHandler: args.KeysHandler, + delayedBlockBroadcaster: args.DelayedBroadcaster, } - cm.delayedBlockBroadcaster = dbb - scm := &shardChainMessenger{ commonMessenger: cm, } - err = dbb.SetBroadcastHandlers(scm.BroadcastMiniBlocks, scm.BroadcastTransactions, scm.BroadcastHeader) + err = scm.delayedBlockBroadcaster.SetBroadcastHandlers( + scm.BroadcastMiniBlocks, + scm.BroadcastTransactions, + scm.BroadcastHeader, + scm.BroadcastConsensusMessage) if err != nil { return nil, err } @@ -136,6 +127,14 
@@ func (scm *shardChainMessenger) BroadcastHeader(header data.HeaderHandler, pkByt return nil } +// BroadcastEquivalentProof will broadcast the proof for a header on the shard metachain common topic +func (scm *shardChainMessenger) BroadcastEquivalentProof(proof data.HeaderProofHandler, pkBytes []byte) error { + shardIdentifier := scm.shardCoordinator.CommunicationIdentifier(core.MetachainShardId) + topic := common.EquivalentProofsTopic + shardIdentifier + + return scm.broadcastEquivalentProof(proof, pkBytes, topic) +} + // BroadcastBlockDataLeader broadcasts the block data as consensus group leader func (scm *shardChainMessenger) BroadcastBlockDataLeader( header data.HeaderHandler, @@ -143,34 +142,60 @@ func (scm *shardChainMessenger) BroadcastBlockDataLeader( transactions map[string][][]byte, pkBytes []byte, ) error { - if check.IfNil(header) { - return spos.ErrNilHeader - } - if len(miniBlocks) == 0 { + if miniBlocks == nil { return nil } - - headerHash, err := core.CalculateHash(scm.marshalizer, scm.hasher, header) + dtb, err := scm.prepareDataToBroadcast(header, miniBlocks, transactions, 0, pkBytes) + if err != nil { + return err + } + err = scm.delayedBlockBroadcaster.SetLeaderData(dtb.delayedBroadcastData) if err != nil { return err } - metaMiniBlocks, metaTransactions := scm.extractMetaMiniBlocksAndTransactions(miniBlocks, transactions) + // TODO: analyze if we can treat it similar to equivalent proofs broadcast (on interceptors) + go scm.BroadcastBlockData(dtb.metaMiniBlocks, dtb.metaTransactions, pkBytes, common.ExtraDelayForBroadcastBlockInfo) + return nil +} - broadcastData := &delayedBroadcastData{ - headerHash: headerHash, - miniBlocksData: miniBlocks, - transactions: transactions, - pkBytes: pkBytes, - } +type dataToBroadcast struct { + delayedBroadcastData *shared.DelayedBroadcastData + metaMiniBlocks map[uint32][]byte + metaTransactions map[string][][]byte +} - err = scm.delayedBlockBroadcaster.SetLeaderData(broadcastData) +func (scm *shardChainMessenger) prepareDataToBroadcast( + header data.HeaderHandler, + miniBlocks map[uint32][]byte, + transactions map[string][][]byte, + order uint32, + pkBytes []byte, +) (*dataToBroadcast, error) { + if check.IfNil(header) { + return nil, spos.ErrNilHeader + } + headerHash, err := core.CalculateHash(scm.marshalizer, scm.hasher, header) if err != nil { - return err + return nil, err } - go scm.BroadcastBlockData(metaMiniBlocks, metaTransactions, pkBytes, common.ExtraDelayForBroadcastBlockInfo) - return nil + metaMiniBlocks, metaTransactions := scm.extractMetaMiniBlocksAndTransactions(miniBlocks, transactions) + + dtb := &dataToBroadcast{ + delayedBroadcastData: &shared.DelayedBroadcastData{ + Header: header, + HeaderHash: headerHash, + MiniBlocksData: miniBlocks, + Transactions: transactions, + Order: order, + PkBytes: pkBytes, + }, + metaMiniBlocks: metaMiniBlocks, + metaTransactions: metaTransactions, + } + + return dtb, nil } // PrepareBroadcastHeaderValidator prepares the validator header broadcast in case leader broadcast fails @@ -188,58 +213,70 @@ func (scm *shardChainMessenger) PrepareBroadcastHeaderValidator( headerHash, err := core.CalculateHash(scm.marshalizer, scm.hasher, header) if err != nil { - log.Error("shardChainMessenger.PrepareBroadcastHeaderValidator", "error", err) + log.Error("shardChainMessenger.PrepareBroadcastHeaderValidator CalculateHash", "error", err) return } - vData := &validatorHeaderBroadcastData{ - headerHash: headerHash, - header: header, - order: uint32(idx), - pkBytes: pkBytes, + vData := 
&shared.ValidatorHeaderBroadcastData{ + HeaderHash: headerHash, + Header: header, + Order: uint32(idx), + PkBytes: pkBytes, } err = scm.delayedBlockBroadcaster.SetHeaderForValidator(vData) if err != nil { - log.Error("shardChainMessenger.PrepareBroadcastHeaderValidator", "error", err) + log.Error("shardChainMessenger.PrepareBroadcastHeaderValidator SetHeaderForValidator", "error", err) return } } -// PrepareBroadcastBlockDataValidator prepares the validator block data broadcast in case leader broadcast fails -func (scm *shardChainMessenger) PrepareBroadcastBlockDataValidator( +// PrepareBroadcastBlockDataWithEquivalentProofs prepares the data to be broadcast when equivalent proofs are activated +func (scm *shardChainMessenger) PrepareBroadcastBlockDataWithEquivalentProofs( header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, - idx int, pkBytes []byte, ) { - if check.IfNil(header) { - log.Error("shardChainMessenger.PrepareBroadcastBlockDataValidator", "error", spos.ErrNilHeader) + if len(miniBlocks) == 0 { return } - if len(miniBlocks) == 0 { + dtb, err := scm.prepareDataToBroadcast(header, miniBlocks, transactions, 0, pkBytes) + if err != nil { + log.Error("shardChainMessenger.PrepareBroadcastBlockDataWithEquivalentProofs prepareDataToBroadcast", "error", err) return } - - headerHash, err := core.CalculateHash(scm.marshalizer, scm.hasher, header) + // everyone broadcasts as if they were the leader + err = scm.delayedBlockBroadcaster.SetLeaderData(dtb.delayedBroadcastData) if err != nil { - log.Error("shardChainMessenger.PrepareBroadcastBlockDataValidator", "error", err) + log.Error("shardChainMessenger.PrepareBroadcastBlockDataWithEquivalentProofs SetLeaderData", "error", err) return } - broadcastData := &delayedBroadcastData{ - headerHash: headerHash, - header: header, - miniBlocksData: miniBlocks, - transactions: transactions, - order: uint32(idx), - pkBytes: pkBytes, + // TODO: consider moving this to the initial block broadcast - optimization + go scm.BroadcastBlockData(dtb.metaMiniBlocks, dtb.metaTransactions, pkBytes, common.ExtraDelayForBroadcastBlockInfo) +} + +// PrepareBroadcastBlockDataValidator prepares the validator block data broadcast in case leader broadcast fails +func (scm *shardChainMessenger) PrepareBroadcastBlockDataValidator( + header data.HeaderHandler, + miniBlocks map[uint32][]byte, + transactions map[string][][]byte, + idx int, + pkBytes []byte, +) { + if len(miniBlocks) == 0 { + return + } + dtb, err := scm.prepareDataToBroadcast(header, miniBlocks, transactions, uint32(idx), pkBytes) + if err != nil { + log.Error("shardChainMessenger.PrepareBroadcastBlockDataValidator prepareDataToBroadcast", "error", err) + return } - err = scm.delayedBlockBroadcaster.SetValidatorData(broadcastData) + err = scm.delayedBlockBroadcaster.SetValidatorData(dtb.delayedBroadcastData) if err != nil { - log.Error("shardChainMessenger.PrepareBroadcastBlockDataValidator", "error", err) + log.Error("shardChainMessenger.PrepareBroadcastBlockDataValidator SetValidatorData", "error", err) return } } diff --git a/consensus/broadcast/shardChainMessenger_test.go b/consensus/broadcast/shardChainMessenger_test.go index c81d2d98c28..7846ba12b0d 100644 --- a/consensus/broadcast/shardChainMessenger_test.go +++ b/consensus/broadcast/shardChainMessenger_test.go @@ -2,13 +2,24 @@ package broadcast_test import ( "bytes" + "errors" "testing" "time" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/stretchr/testify/require" 
+ + "github.com/multiversx/mx-chain-go/consensus" + testscommonConsensus "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/pool" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus/broadcast" + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/p2p" @@ -17,9 +28,10 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" ) +var expectedErr = errors.New("expected error") + func createDelayData(prefix string) ([]byte, *block.Header, map[uint32][]byte, map[string][][]byte) { miniblocks := make(map[uint32][]byte) receiverShardID := uint32(1) @@ -44,8 +56,8 @@ func createInterceptorContainer() process.InterceptorsContainer { return &testscommon.InterceptorsContainerStub{ GetCalled: func(topic string) (process.Interceptor, error) { return &testscommon.InterceptorStub{ - ProcessReceivedMessageCalled: func(message p2p.MessageP2P) error { - return nil + ProcessReceivedMessageCalled: func(message p2p.MessageP2P) ([]byte, error) { + return nil, nil }, }, nil }, @@ -58,12 +70,13 @@ func createDefaultShardChainArgs() broadcast.ShardChainMessengerArgs { messengerMock := &p2pmocks.MessengerStub{} shardCoordinatorMock := &mock.ShardCoordinatorMock{} singleSignerMock := &mock.SingleSignerMock{} - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptorsContainer := createInterceptorContainer() peerSigHandler := &mock.PeerSignatureHandler{ Signer: singleSignerMock, } - alarmScheduler := &mock.AlarmSchedulerStub{} + alarmScheduler := &testscommon.AlarmSchedulerStub{} + delayedBroadcaster := &testscommonConsensus.DelayedBroadcasterMock{} return broadcast.ShardChainMessengerArgs{ CommonMessengerArgs: broadcast.CommonMessengerArgs{ @@ -78,6 +91,20 @@ func createDefaultShardChainArgs() broadcast.ShardChainMessengerArgs { MaxValidatorDelayCacheSize: 1, AlarmScheduler: alarmScheduler, KeysHandler: &testscommon.KeysHandlerStub{}, + DelayedBroadcaster: delayedBroadcaster, + }, + } +} + +func newBlockWithEmptyMiniblock() *block.Body { + return &block.Body{ + MiniBlocks: []*block.MiniBlock{ + { + TxHashes: [][]byte{}, + ReceiverShardID: 0, + SenderShardID: 0, + Type: 0, + }, }, } } @@ -85,6 +112,7 @@ func createDefaultShardChainArgs() broadcast.ShardChainMessengerArgs { func TestShardChainMessenger_NewShardChainMessengerNilMarshalizerShouldFail(t *testing.T) { args := createDefaultShardChainArgs() args.Marshalizer = nil + scm, err := broadcast.NewShardChainMessenger(args) assert.Nil(t, scm) @@ -136,6 +164,15 @@ func TestShardChainMessenger_NewShardChainMessengerNilHeadersSubscriberShouldFai assert.Equal(t, spos.ErrNilHeadersSubscriber, err) } +func TestShardChainMessenger_NilDelayedBroadcasterShouldError(t *testing.T) { + args := createDefaultShardChainArgs() + 
args.DelayedBroadcaster = nil + scm, err := broadcast.NewShardChainMessenger(args) + + assert.Nil(t, scm) + assert.Equal(t, broadcast.ErrNilDelayedBroadcaster, err) +} + func TestShardChainMessenger_NilKeysHandlerShouldError(t *testing.T) { args := createDefaultShardChainArgs() args.KeysHandler = nil @@ -154,6 +191,25 @@ func TestShardChainMessenger_NewShardChainMessengerShouldWork(t *testing.T) { assert.False(t, scm.IsInterfaceNil()) } +func TestShardChainMessenger_NewShardChainMessengerShouldErr(t *testing.T) { + + args := createDefaultShardChainArgs() + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetBroadcastHandlersCalled: func( + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error, + ) error { + return expectedErr + }} + + _, err := broadcast.NewShardChainMessenger(args) + + assert.Equal(t, expectedErr, err) + +} + func TestShardChainMessenger_BroadcastBlockShouldErrNilBody(t *testing.T) { args := createDefaultShardChainArgs() scm, _ := broadcast.NewShardChainMessenger(args) @@ -170,6 +226,14 @@ func TestShardChainMessenger_BroadcastBlockShouldErrNilHeader(t *testing.T) { assert.Equal(t, spos.ErrNilHeader, err) } +func TestShardChainMessenger_BroadcastBlockShouldErrMiniBlockEmpty(t *testing.T) { + args := createDefaultShardChainArgs() + scm, _ := broadcast.NewShardChainMessenger(args) + + err := scm.BroadcastBlock(newBlockWithEmptyMiniblock(), &block.Header{}) + assert.Equal(t, data.ErrMiniBlockEmpty, err) +} + func TestShardChainMessenger_BroadcastBlockShouldErrMockMarshalizer(t *testing.T) { marshalizer := mock.MarshalizerMock{ Fail: true, @@ -363,6 +427,19 @@ func TestShardChainMessenger_BroadcastHeaderNilHeaderShouldErr(t *testing.T) { assert.Equal(t, spos.ErrNilHeader, err) } +func TestShardChainMessenger_BroadcastHeaderShouldErr(t *testing.T) { + marshalizer := mock.MarshalizerMock{ + Fail: true, + } + + args := createDefaultShardChainArgs() + args.Marshalizer = marshalizer + scm, _ := broadcast.NewShardChainMessenger(args) + + err := scm.BroadcastHeader(&block.MetaBlock{Nonce: 10}, []byte("pk bytes")) + assert.Equal(t, mock.ErrMockMarshalizer, err) +} + func TestShardChainMessenger_BroadcastHeaderShouldWork(t *testing.T) { channelBroadcastCalled := make(chan bool, 1) channelBroadcastUsingPrivateKeyCalled := make(chan bool, 1) @@ -439,6 +516,41 @@ func TestShardChainMessenger_BroadcastBlockDataLeaderNilMiniblocksShouldReturnNi assert.Nil(t, err) } +func TestShardChainMessenger_BroadcastBlockDataLeaderShouldErr(t *testing.T) { + marshalizer := mock.MarshalizerMock{ + Fail: true, + } + + args := createDefaultShardChainArgs() + args.Marshalizer = marshalizer + + scm, _ := broadcast.NewShardChainMessenger(args) + + _, header, miniblocks, transactions := createDelayData("1") + + err := scm.BroadcastBlockDataLeader(header, miniblocks, transactions, []byte("pk bytes")) + assert.Equal(t, mock.ErrMockMarshalizer, err) +} + +func TestShardChainMessenger_BroadcastBlockDataLeaderShouldErrDelayedBroadcaster(t *testing.T) { + + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetLeaderDataCalled: func(data *shared.DelayedBroadcastData) error { + return expectedErr + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + _, header, miniblocks, 
transactions := createDelayData("1") + + err := scm.BroadcastBlockDataLeader(header, miniblocks, transactions, []byte("pk bytes")) + + assert.Equal(t, expectedErr, err) +} + func TestShardChainMessenger_BroadcastBlockDataLeaderShouldTriggerWaitingDelayedMessage(t *testing.T) { broadcastWasCalled := atomic.Flag{} broadcastUsingPrivateKeyWasCalled := atomic.Flag{} @@ -457,6 +569,18 @@ func TestShardChainMessenger_BroadcastBlockDataLeaderShouldTriggerWaitingDelayed return bytes.Equal(pkBytes, nodePkBytes) }, } + argsDelayedBroadcaster := broadcast.ArgsDelayedBlockBroadcaster{ + InterceptorsContainer: args.InterceptorsContainer, + HeadersSubscriber: args.HeadersSubscriber, + ShardCoordinator: args.ShardCoordinator, + LeaderCacheSize: args.MaxDelayCacheSize, + ValidatorCacheSize: args.MaxDelayCacheSize, + AlarmScheduler: args.AlarmScheduler, + } + + // Using real component in order to properly simulate the expected behavior + args.DelayedBroadcaster, _ = broadcast.NewDelayedBlockBroadcaster(&argsDelayedBroadcaster) + scm, _ := broadcast.NewShardChainMessenger(args) t.Run("original public key of the node", func(t *testing.T) { @@ -488,3 +612,190 @@ func TestShardChainMessenger_BroadcastBlockDataLeaderShouldTriggerWaitingDelayed assert.True(t, broadcastUsingPrivateKeyWasCalled.IsSet()) }) } + +func TestShardChainMessenger_PrepareBroadcastHeaderValidatorShouldFailHeaderNil(t *testing.T) { + + pkBytes := make([]byte, 32) + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + require.Fail(t, "SetHeaderForValidator should not be called") + return nil + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastHeaderValidator(nil, nil, nil, 1, pkBytes) +} + +func TestShardChainMessenger_PrepareBroadcastHeaderValidatorShouldFailCalculateHashErr(t *testing.T) { + + pkBytes := make([]byte, 32) + headerMock := &testscommon.HeaderHandlerStub{} + + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + require.Fail(t, "SetHeaderForValidator should not be called") + return nil + }} + + args.Marshalizer = &testscommon.MarshallerStub{MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, expectedErr + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastHeaderValidator(headerMock, nil, nil, 1, pkBytes) +} + +func TestShardChainMessenger_PrepareBroadcastHeaderValidatorShouldWork(t *testing.T) { + + pkBytes := make([]byte, 32) + headerMock := &testscommon.HeaderHandlerStub{} + + args := createDefaultShardChainArgs() + + varSetHeaderForValidatorCalled := false + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetHeaderForValidatorCalled: func(vData *shared.ValidatorHeaderBroadcastData) error { + varSetHeaderForValidatorCalled = true + return nil + }} + + args.Marshalizer = &testscommon.MarshallerStub{MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, nil + }} + args.Hasher = &testscommon.HasherStub{ComputeCalled: func(s string) []byte { + return nil + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastHeaderValidator(headerMock, nil, nil, 1, pkBytes) + + assert.True(t, varSetHeaderForValidatorCalled) +} + 
+func TestShardChainMessenger_PrepareBroadcastBlockDataValidatorShouldFailHeaderNil(t *testing.T) { + + pkBytes := make([]byte, 32) + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetValidatorDataCalled: func(data *shared.DelayedBroadcastData) error { + require.Fail(t, "SetValidatorData should not be called") + return nil + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastBlockDataValidator(nil, nil, nil, 1, pkBytes) +} + +func TestShardChainMessenger_PrepareBroadcastBlockDataValidatorShouldFailMiniBlocksLenZero(t *testing.T) { + + pkBytes := make([]byte, 32) + miniBlocks := make(map[uint32][]byte) + headerMock := &testscommon.HeaderHandlerStub{} + + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetValidatorDataCalled: func(data *shared.DelayedBroadcastData) error { + require.Fail(t, "SetValidatorData should not be called") + return nil + }} + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastBlockDataValidator(headerMock, miniBlocks, nil, 1, pkBytes) +} + +func TestShardChainMessenger_PrepareBroadcastBlockDataValidatorShouldFailCalculateHashErr(t *testing.T) { + + pkBytes := make([]byte, 32) + miniBlocks := map[uint32][]byte{1: {}} + headerMock := &testscommon.HeaderHandlerStub{} + + args := createDefaultShardChainArgs() + + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetValidatorDataCalled: func(data *shared.DelayedBroadcastData) error { + require.Fail(t, "SetValidatorData should not be called") + return nil + }} + + args.Marshalizer = &testscommon.MarshallerStub{ + MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, expectedErr + }, + } + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastBlockDataValidator(headerMock, miniBlocks, nil, 1, pkBytes) +} + +func TestShardChainMessenger_PrepareBroadcastBlockDataValidatorShouldWork(t *testing.T) { + + pkBytes := make([]byte, 32) + miniBlocks := map[uint32][]byte{1: {}} + headerMock := &testscommon.HeaderHandlerStub{} + + args := createDefaultShardChainArgs() + + varSetValidatorDataCalled := false + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + SetValidatorDataCalled: func(data *shared.DelayedBroadcastData) error { + varSetValidatorDataCalled = true + return nil + }} + + args.Marshalizer = &testscommon.MarshallerStub{ + MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, nil + }, + } + + args.Hasher = &testscommon.HasherStub{ + ComputeCalled: func(s string) []byte { + return nil + }, + } + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.PrepareBroadcastBlockDataValidator(headerMock, miniBlocks, nil, 1, pkBytes) + + assert.True(t, varSetValidatorDataCalled) +} + +func TestShardChainMessenger_CloseShouldWork(t *testing.T) { + + args := createDefaultShardChainArgs() + + varCloseCalled := false + args.DelayedBroadcaster = &testscommonConsensus.DelayedBroadcasterMock{ + CloseCalled: func() { + varCloseCalled = true + }, + } + + scm, _ := broadcast.NewShardChainMessenger(args) + require.NotNil(t, scm) + + scm.Close() + assert.True(t, varCloseCalled) + +} diff --git a/consensus/broadcast/shared/types.go b/consensus/broadcast/shared/types.go new file mode 100644 index 00000000000..216cd5987b8 --- /dev/null +++ 
b/consensus/broadcast/shared/types.go @@ -0,0 +1,26 @@ +package shared + +import ( + "github.com/multiversx/mx-chain-core-go/data" +) + +// DelayedBroadcastData is exported to be accessible in delayedBroadcasterMock +type DelayedBroadcastData struct { + HeaderHash []byte + Header data.HeaderHandler + MiniBlocksData map[uint32][]byte + MiniBlockHashes map[string]map[string]struct{} + Transactions map[string][][]byte + Order uint32 + PkBytes []byte +} + +// ValidatorHeaderBroadcastData is exported to be accessible in delayedBroadcasterMock +type ValidatorHeaderBroadcastData struct { + HeaderHash []byte + Header data.HeaderHandler + MetaMiniBlocksData map[uint32][]byte + MetaTransactionsData map[string][][]byte + Order uint32 + PkBytes []byte +} diff --git a/consensus/chronology/chronology.go b/consensus/chronology/chronology.go index 1b20bc1dc03..f4dc90604b7 100644 --- a/consensus/chronology/chronology.go +++ b/consensus/chronology/chronology.go @@ -10,10 +10,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/closing" "github.com/multiversx/mx-chain-core-go/display" + "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/ntp" - "github.com/multiversx/mx-chain-logger-go" ) var _ consensus.ChronologyHandler = (*chronology)(nil) @@ -103,6 +104,7 @@ func (chr *chronology) RemoveAllSubrounds() { chr.subrounds = make(map[int]int) chr.subroundHandlers = make([]consensus.SubroundHandler, 0) + chr.subroundId = srBeforeStartRound chr.mutSubrounds.Unlock() } @@ -118,6 +120,9 @@ func (chr *chronology) StartRounds() { } func (chr *chronology) startRounds(ctx context.Context) { + // force a round update to initialize the round + roundHandlerWithRevert := chr.roundHandler.(consensus.RoundHandlerConsensusSwitch) + roundHandlerWithRevert.RevertOneRound() for { select { case <-ctx.Done(): diff --git a/consensus/chronology/chronology_test.go b/consensus/chronology/chronology_test.go index 978d898834c..f7a9b70adb5 100644 --- a/consensus/chronology/chronology_test.go +++ b/consensus/chronology/chronology_test.go @@ -5,11 +5,14 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/chronology" "github.com/multiversx/mx-chain-go/consensus/mock" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) func initSubroundHandlerMock() *mock.SubroundHandlerMock { @@ -115,7 +118,7 @@ func TestChronology_StartRoundShouldReturnWhenRoundIndexIsNegative(t *testing.T) t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} roundHandlerMock.IndexCalled = func() int64 { return -1 } @@ -149,7 +152,7 @@ func TestChronology_StartRoundShouldReturnWhenDoWorkReturnsFalse(t *testing.T) { t.Parallel() arg := getDefaultChronologyArg() - 
roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} roundHandlerMock.UpdateRound(roundHandlerMock.TimeStamp(), roundHandlerMock.TimeStamp().Add(roundHandlerMock.TimeDuration())) arg.RoundHandler = roundHandlerMock chr, _ := chronology.NewChronology(arg) @@ -166,7 +169,7 @@ func TestChronology_StartRoundShouldWork(t *testing.T) { t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} roundHandlerMock.UpdateRound(roundHandlerMock.TimeStamp(), roundHandlerMock.TimeStamp().Add(roundHandlerMock.TimeDuration())) arg.RoundHandler = roundHandlerMock chr, _ := chronology.NewChronology(arg) @@ -219,7 +222,7 @@ func TestChronology_InitRoundShouldNotSetSubroundWhenRoundIndexIsNegative(t *tes t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} arg.RoundHandler = roundHandlerMock arg.GenesisTime = arg.SyncTimer.CurrentTime() chr, _ := chronology.NewChronology(arg) @@ -240,7 +243,7 @@ func TestChronology_InitRoundShouldSetSubroundWhenRoundIndexIsPositive(t *testin t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} roundHandlerMock.UpdateRound(roundHandlerMock.TimeStamp(), roundHandlerMock.TimeStamp().Add(roundHandlerMock.TimeDuration())) arg.RoundHandler = roundHandlerMock arg.GenesisTime = arg.SyncTimer.CurrentTime() @@ -257,7 +260,7 @@ func TestChronology_StartRoundShouldNotUpdateRoundWhenCurrentRoundIsNotFinished( t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} arg.RoundHandler = roundHandlerMock arg.GenesisTime = arg.SyncTimer.CurrentTime() chr, _ := chronology.NewChronology(arg) @@ -271,7 +274,7 @@ func TestChronology_StartRoundShouldNotUpdateRoundWhenCurrentRoundIsNotFinished( func TestChronology_StartRoundShouldUpdateRoundWhenCurrentRoundIsFinished(t *testing.T) { t.Parallel() arg := getDefaultChronologyArg() - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} arg.RoundHandler = roundHandlerMock arg.GenesisTime = arg.SyncTimer.CurrentTime() chr, _ := chronology.NewChronology(arg) @@ -315,9 +318,97 @@ func TestChronology_CheckIfStatusHandlerWorks(t *testing.T) { func getDefaultChronologyArg() chronology.ArgChronology { return chronology.ArgChronology{ GenesisTime: time.Now(), - RoundHandler: &mock.RoundHandlerMock{}, - SyncTimer: &mock.SyncTimerMock{}, + RoundHandler: &consensusMocks.RoundHandlerMock{}, + SyncTimer: &consensusMocks.SyncTimerMock{}, AppStatusHandler: statusHandlerMock.NewAppStatusHandlerMock(), Watchdog: &mock.WatchdogMock{}, } } + +func TestChronology_CloseWatchDogStop(t *testing.T) { + t.Parallel() + + arg := getDefaultChronologyArg() + stopCalled := false + arg.Watchdog = &mock.WatchdogMock{ + StopCalled: func(alarmID string) { + stopCalled = true + }, + } + + chr, err := chronology.NewChronology(arg) + require.Nil(t, err) + chr.SetCancelFunc(nil) + + err = chr.Close() + assert.Nil(t, err) + assert.True(t, stopCalled) +} + +func TestChronology_Close(t *testing.T) { + t.Parallel() + + arg := getDefaultChronologyArg() + stopCalled := false + arg.Watchdog = &mock.WatchdogMock{ + StopCalled: func(alarmID string) { + stopCalled = true + }, + } + + chr, err := chronology.NewChronology(arg) + 
require.Nil(t, err) + + cancelCalled := false + chr.SetCancelFunc(func() { + cancelCalled = true + }) + + err = chr.Close() + assert.Nil(t, err) + assert.True(t, stopCalled) + assert.True(t, cancelCalled) +} + +func TestChronology_StartRounds(t *testing.T) { + t.Parallel() + + arg := getDefaultChronologyArg() + + chr, err := chronology.NewChronology(arg) + require.Nil(t, err) + doneFuncCalled := false + + ctx := &mock.ContextMock{ + DoneFunc: func() <-chan struct{} { + done := make(chan struct{}) + close(done) + doneFuncCalled = true + return done + }, + } + chr.StartRoundsTest(ctx) + assert.True(t, doneFuncCalled) +} + +func TestChronology_StartRoundsShouldWork(t *testing.T) { + t.Parallel() + + arg := getDefaultChronologyArg() + roundHandlerMock := &consensusMocks.RoundHandlerMock{} + roundHandlerMock.UpdateRound(roundHandlerMock.TimeStamp(), roundHandlerMock.TimeStamp().Add(roundHandlerMock.TimeDuration())) + arg.RoundHandler = roundHandlerMock + chr, _ := chronology.NewChronology(arg) + + srm := initSubroundHandlerMock() + srm.DoWorkCalled = func(roundHandler consensus.RoundHandler) bool { + return true + } + chr.AddSubround(srm) + chr.SetSubroundId(1) + chr.StartRounds() + defer chr.Close() + + assert.Equal(t, srm.Next(), chr.SubroundId()) + time.Sleep(time.Millisecond * 10) +} diff --git a/consensus/chronology/export_test.go b/consensus/chronology/export_test.go index 39ff4cab99f..b3a35131597 100644 --- a/consensus/chronology/export_test.go +++ b/consensus/chronology/export_test.go @@ -3,6 +3,8 @@ package chronology import ( "context" + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/consensus" ) @@ -37,3 +39,18 @@ func (chr *chronology) UpdateRound() { func (chr *chronology) InitRound() { chr.initRound() } + +// StartRoundsTest calls the unexported startRounds function +func (chr *chronology) StartRoundsTest(ctx context.Context) { + chr.startRounds(ctx) +} + +// SetWatchdog sets the watchdog for chronology object +func (chr *chronology) SetWatchdog(watchdog core.WatchdogTimer) { + chr.watchdog = watchdog +} + +// SetCancelFunc sets cancelFunc for chronology object +func (chr *chronology) SetCancelFunc(cancelFunc func()) { + chr.cancelFunc = cancelFunc +} diff --git a/consensus/interface.go b/consensus/interface.go index aa8d9057bc4..27e2916110a 100644 --- a/consensus/interface.go +++ b/consensus/interface.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/p2p" ) @@ -25,6 +26,12 @@ type RoundHandler interface { IsInterfaceNil() bool } +// RoundHandlerConsensusSwitch defines the actions which should be handled by a consensus switch round implementation +type RoundHandlerConsensusSwitch interface { + RevertOneRound() + IsInterfaceNil() bool +} + // SubroundHandler defines the actions which should be handled by a subround implementation type SubroundHandler interface { // DoWork implements of the subround's job @@ -61,12 +68,14 @@ type ChronologyHandler interface { type BroadcastMessenger interface { BroadcastBlock(data.BodyHandler, data.HeaderHandler) error BroadcastHeader(data.HeaderHandler, []byte) error + BroadcastEquivalentProof(proof data.HeaderProofHandler, pkBytes []byte) error BroadcastMiniBlocks(map[uint32][]byte, []byte) error BroadcastTransactions(map[string][][]byte, []byte) error 
BroadcastConsensusMessage(*Message) error BroadcastBlockDataLeader(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, pkBytes []byte) error PrepareBroadcastHeaderValidator(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, idx int, pkBytes []byte) PrepareBroadcastBlockDataValidator(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, idx int, pkBytes []byte) + PrepareBroadcastBlockDataWithEquivalentProofs(header data.HeaderHandler, miniBlocks map[uint32][]byte, transactions map[string][][]byte, pkBytes []byte) IsInterfaceNil() bool } @@ -122,11 +131,14 @@ type HeaderSigVerifier interface { VerifyRandSeed(header data.HeaderHandler) error VerifyLeaderSignature(header data.HeaderHandler) error VerifySignature(header data.HeaderHandler) error + VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error + VerifyHeaderProof(headerProof data.HeaderProofHandler) error IsInterfaceNil() bool } // FallbackHeaderValidator defines the behaviour of a component able to signal when a fallback header validation could be applied type FallbackHeaderValidator interface { + ShouldApplyFallbackValidationForHeaderWith(shardID uint32, startOfEpochBlock bool, round uint64, prevHeaderHash []byte) bool ShouldApplyFallbackValidation(headerHandler data.HeaderHandler) bool IsInterfaceNil() bool } @@ -193,3 +205,24 @@ type KeysHandler interface { GetRedundancyStepInReason() string IsInterfaceNil() bool } + +// EquivalentProofsPool defines the behaviour of a proofs pool component +type EquivalentProofsPool interface { + AddProof(headerProof data.HeaderProofHandler) bool + GetProof(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) + GetProofByNonce(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) + HasProof(shardID uint32, headerHash []byte) bool + IsInterfaceNil() bool +} + +// ProofHandler defines the interface for a proof handler +type ProofHandler interface { + GetPubKeysBitmap() []byte + GetAggregatedSignature() []byte + GetHeaderHash() []byte + GetHeaderEpoch() uint32 + GetHeaderNonce() uint64 + GetHeaderShardId() uint32 + GetIsStartOfEpoch() bool + IsInterfaceNil() bool +} diff --git a/consensus/message.go b/consensus/message.go index f4396c05076..3e581673d17 100644 --- a/consensus/message.go +++ b/consensus/message.go @@ -1,7 +1,9 @@ //go:generate protoc -I=. -I=$GOPATH/src -I=$GOPATH/src/github.com/multiversx/protobuf/protobuf --gogoslick_out=.
message.proto package consensus -import "github.com/multiversx/mx-chain-core-go/core" +import ( + "github.com/multiversx/mx-chain-core-go/core" +) // MessageType specifies what type of message was received type MessageType int diff --git a/consensus/mock/alarmSchedulerStub.go b/consensus/mock/alarmSchedulerStub.go deleted file mode 100644 index fe2e7597036..00000000000 --- a/consensus/mock/alarmSchedulerStub.go +++ /dev/null @@ -1,45 +0,0 @@ -package mock - -import ( - "time" -) - -type AlarmSchedulerStub struct { - AddCalled func(func(alarmID string), time.Duration, string) - CancelCalled func(string) - CloseCalled func() - ResetCalled func(string) -} - -// Add - -func (a *AlarmSchedulerStub) Add(callback func(alarmID string), duration time.Duration, alarmID string) { - if a.AddCalled != nil { - a.AddCalled(callback, duration, alarmID) - } -} - -// Cancel - -func (a *AlarmSchedulerStub) Cancel(alarmID string) { - if a.CancelCalled != nil { - a.CancelCalled(alarmID) - } -} - -// Close - -func (a *AlarmSchedulerStub) Close() { - if a.CloseCalled != nil { - a.CloseCalled() - } -} - -// Reset - -func (a *AlarmSchedulerStub) Reset(alarmID string) { - if a.ResetCalled != nil { - a.ResetCalled(alarmID) - } -} - -// IsInterfaceNil - -func (a *AlarmSchedulerStub) IsInterfaceNil() bool { - return a == nil -} diff --git a/consensus/mock/consensusDataContainerMock.go b/consensus/mock/consensusDataContainerMock.go deleted file mode 100644 index 88f837b1da1..00000000000 --- a/consensus/mock/consensusDataContainerMock.go +++ /dev/null @@ -1,246 +0,0 @@ -package mock - -import ( - "github.com/multiversx/mx-chain-core-go/data" - "github.com/multiversx/mx-chain-core-go/hashing" - "github.com/multiversx/mx-chain-core-go/marshal" - cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" - "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/epochStart" - "github.com/multiversx/mx-chain-go/ntp" - "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/sharding" - "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" -) - -// ConsensusCoreMock - -type ConsensusCoreMock struct { - blockChain data.ChainHandler - blockProcessor process.BlockProcessor - headersSubscriber consensus.HeadersPoolSubscriber - bootstrapper process.Bootstrapper - broadcastMessenger consensus.BroadcastMessenger - chronologyHandler consensus.ChronologyHandler - hasher hashing.Hasher - marshalizer marshal.Marshalizer - multiSignerContainer cryptoCommon.MultiSignerContainer - roundHandler consensus.RoundHandler - shardCoordinator sharding.Coordinator - syncTimer ntp.SyncTimer - validatorGroupSelector nodesCoordinator.NodesCoordinator - epochStartNotifier epochStart.RegistrationHandler - antifloodHandler consensus.P2PAntifloodHandler - peerHonestyHandler consensus.PeerHonestyHandler - headerSigVerifier consensus.HeaderSigVerifier - fallbackHeaderValidator consensus.FallbackHeaderValidator - nodeRedundancyHandler consensus.NodeRedundancyHandler - scheduledProcessor consensus.ScheduledProcessor - messageSigningHandler consensus.P2PSigningHandler - peerBlacklistHandler consensus.PeerBlacklistHandler - signingHandler consensus.SigningHandler -} - -// GetAntiFloodHandler - -func (ccm *ConsensusCoreMock) GetAntiFloodHandler() consensus.P2PAntifloodHandler { - return 
ccm.antifloodHandler -} - -// Blockchain - -func (ccm *ConsensusCoreMock) Blockchain() data.ChainHandler { - return ccm.blockChain -} - -// BlockProcessor - -func (ccm *ConsensusCoreMock) BlockProcessor() process.BlockProcessor { - return ccm.blockProcessor -} - -// HeadersPoolSubscriber - -func (ccm *ConsensusCoreMock) HeadersPoolSubscriber() consensus.HeadersPoolSubscriber { - return ccm.headersSubscriber -} - -// BootStrapper - -func (ccm *ConsensusCoreMock) BootStrapper() process.Bootstrapper { - return ccm.bootstrapper -} - -// BroadcastMessenger - -func (ccm *ConsensusCoreMock) BroadcastMessenger() consensus.BroadcastMessenger { - return ccm.broadcastMessenger -} - -// Chronology - -func (ccm *ConsensusCoreMock) Chronology() consensus.ChronologyHandler { - return ccm.chronologyHandler -} - -// Hasher - -func (ccm *ConsensusCoreMock) Hasher() hashing.Hasher { - return ccm.hasher -} - -// Marshalizer - -func (ccm *ConsensusCoreMock) Marshalizer() marshal.Marshalizer { - return ccm.marshalizer -} - -// MultiSignerContainer - -func (ccm *ConsensusCoreMock) MultiSignerContainer() cryptoCommon.MultiSignerContainer { - return ccm.multiSignerContainer -} - -// RoundHandler - -func (ccm *ConsensusCoreMock) RoundHandler() consensus.RoundHandler { - return ccm.roundHandler -} - -// ShardCoordinator - -func (ccm *ConsensusCoreMock) ShardCoordinator() sharding.Coordinator { - return ccm.shardCoordinator -} - -// SyncTimer - -func (ccm *ConsensusCoreMock) SyncTimer() ntp.SyncTimer { - return ccm.syncTimer -} - -// NodesCoordinator - -func (ccm *ConsensusCoreMock) NodesCoordinator() nodesCoordinator.NodesCoordinator { - return ccm.validatorGroupSelector -} - -// EpochStartRegistrationHandler - -func (ccm *ConsensusCoreMock) EpochStartRegistrationHandler() epochStart.RegistrationHandler { - return ccm.epochStartNotifier -} - -// SetBlockchain - -func (ccm *ConsensusCoreMock) SetBlockchain(blockChain data.ChainHandler) { - ccm.blockChain = blockChain -} - -// SetBlockProcessor - -func (ccm *ConsensusCoreMock) SetBlockProcessor(blockProcessor process.BlockProcessor) { - ccm.blockProcessor = blockProcessor -} - -// SetBootStrapper - -func (ccm *ConsensusCoreMock) SetBootStrapper(bootstrapper process.Bootstrapper) { - ccm.bootstrapper = bootstrapper -} - -// SetBroadcastMessenger - -func (ccm *ConsensusCoreMock) SetBroadcastMessenger(broadcastMessenger consensus.BroadcastMessenger) { - ccm.broadcastMessenger = broadcastMessenger -} - -// SetChronology - -func (ccm *ConsensusCoreMock) SetChronology(chronologyHandler consensus.ChronologyHandler) { - ccm.chronologyHandler = chronologyHandler -} - -// SetHasher - -func (ccm *ConsensusCoreMock) SetHasher(hasher hashing.Hasher) { - ccm.hasher = hasher -} - -// SetMarshalizer - -func (ccm *ConsensusCoreMock) SetMarshalizer(marshalizer marshal.Marshalizer) { - ccm.marshalizer = marshalizer -} - -// SetMultiSignerContainer - -func (ccm *ConsensusCoreMock) SetMultiSignerContainer(multiSignerContainer cryptoCommon.MultiSignerContainer) { - ccm.multiSignerContainer = multiSignerContainer -} - -// SetRoundHandler - -func (ccm *ConsensusCoreMock) SetRoundHandler(roundHandler consensus.RoundHandler) { - ccm.roundHandler = roundHandler -} - -// SetShardCoordinator - -func (ccm *ConsensusCoreMock) SetShardCoordinator(shardCoordinator sharding.Coordinator) { - ccm.shardCoordinator = shardCoordinator -} - -// SetSyncTimer - -func (ccm *ConsensusCoreMock) SetSyncTimer(syncTimer ntp.SyncTimer) { - ccm.syncTimer = syncTimer -} - -// SetValidatorGroupSelector - -func (ccm 
*ConsensusCoreMock) SetValidatorGroupSelector(validatorGroupSelector nodesCoordinator.NodesCoordinator) { - ccm.validatorGroupSelector = validatorGroupSelector -} - -// PeerHonestyHandler - -func (ccm *ConsensusCoreMock) PeerHonestyHandler() consensus.PeerHonestyHandler { - return ccm.peerHonestyHandler -} - -// HeaderSigVerifier - -func (ccm *ConsensusCoreMock) HeaderSigVerifier() consensus.HeaderSigVerifier { - return ccm.headerSigVerifier -} - -// SetHeaderSigVerifier - -func (ccm *ConsensusCoreMock) SetHeaderSigVerifier(headerSigVerifier consensus.HeaderSigVerifier) { - ccm.headerSigVerifier = headerSigVerifier -} - -// FallbackHeaderValidator - -func (ccm *ConsensusCoreMock) FallbackHeaderValidator() consensus.FallbackHeaderValidator { - return ccm.fallbackHeaderValidator -} - -// SetFallbackHeaderValidator - -func (ccm *ConsensusCoreMock) SetFallbackHeaderValidator(fallbackHeaderValidator consensus.FallbackHeaderValidator) { - ccm.fallbackHeaderValidator = fallbackHeaderValidator -} - -// NodeRedundancyHandler - -func (ccm *ConsensusCoreMock) NodeRedundancyHandler() consensus.NodeRedundancyHandler { - return ccm.nodeRedundancyHandler -} - -// ScheduledProcessor - -func (ccm *ConsensusCoreMock) ScheduledProcessor() consensus.ScheduledProcessor { - return ccm.scheduledProcessor -} - -// SetNodeRedundancyHandler - -func (ccm *ConsensusCoreMock) SetNodeRedundancyHandler(nodeRedundancyHandler consensus.NodeRedundancyHandler) { - ccm.nodeRedundancyHandler = nodeRedundancyHandler -} - -// MessageSigningHandler - -func (ccm *ConsensusCoreMock) MessageSigningHandler() consensus.P2PSigningHandler { - return ccm.messageSigningHandler -} - -// SetMessageSigningHandler - -func (ccm *ConsensusCoreMock) SetMessageSigningHandler(messageSigningHandler consensus.P2PSigningHandler) { - ccm.messageSigningHandler = messageSigningHandler -} - -// PeerBlacklistHandler will return the peer blacklist handler -func (ccm *ConsensusCoreMock) PeerBlacklistHandler() consensus.PeerBlacklistHandler { - return ccm.peerBlacklistHandler -} - -// SigningHandler - -func (ccm *ConsensusCoreMock) SigningHandler() consensus.SigningHandler { - return ccm.signingHandler -} - -// SetSigningHandler - -func (ccm *ConsensusCoreMock) SetSigningHandler(signingHandler consensus.SigningHandler) { - ccm.signingHandler = signingHandler -} - -// IsInterfaceNil returns true if there is no value under the interface -func (ccm *ConsensusCoreMock) IsInterfaceNil() bool { - return ccm == nil -} diff --git a/consensus/mock/consensusStateMock.go b/consensus/mock/consensusStateMock.go deleted file mode 100644 index fb4fb708449..00000000000 --- a/consensus/mock/consensusStateMock.go +++ /dev/null @@ -1,137 +0,0 @@ -package mock - -import ( - "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" -) - -// ConsensusStateMock - -type ConsensusStateMock struct { - ResetConsensusStateCalled func() - IsNodeLeaderInCurrentRoundCalled func(node string) bool - IsSelfLeaderInCurrentRoundCalled func() bool - GetLeaderCalled func() (string, error) - GetNextConsensusGroupCalled func(randomSource string, vgs nodesCoordinator.NodesCoordinator) ([]string, error) - IsConsensusDataSetCalled func() bool - IsConsensusDataEqualCalled func(data []byte) bool - IsJobDoneCalled func(node string, currentSubroundId int) bool - IsSelfJobDoneCalled func(currentSubroundId int) bool - IsCurrentSubroundFinishedCalled func(currentSubroundId int) bool - IsNodeSelfCalled func(node string) 
bool - IsBlockBodyAlreadyReceivedCalled func() bool - IsHeaderAlreadyReceivedCalled func() bool - CanDoSubroundJobCalled func(currentSubroundId int) bool - CanProcessReceivedMessageCalled func(cnsDta consensus.Message, currentRoundIndex int32, currentSubroundId int) bool - GenerateBitmapCalled func(subroundId int) []byte - ProcessingBlockCalled func() bool - SetProcessingBlockCalled func(processingBlock bool) - ConsensusGroupSizeCalled func() int - SetThresholdCalled func(subroundId int, threshold int) -} - -// ResetConsensusState - -func (cnsm *ConsensusStateMock) ResetConsensusState() { - cnsm.ResetConsensusStateCalled() -} - -// IsNodeLeaderInCurrentRound - -func (cnsm *ConsensusStateMock) IsNodeLeaderInCurrentRound(node string) bool { - return cnsm.IsNodeLeaderInCurrentRoundCalled(node) -} - -// IsSelfLeaderInCurrentRound - -func (cnsm *ConsensusStateMock) IsSelfLeaderInCurrentRound() bool { - return cnsm.IsSelfLeaderInCurrentRoundCalled() -} - -// GetLeader - -func (cnsm *ConsensusStateMock) GetLeader() (string, error) { - return cnsm.GetLeaderCalled() -} - -// GetNextConsensusGroup - -func (cnsm *ConsensusStateMock) GetNextConsensusGroup( - randomSource string, - vgs nodesCoordinator.NodesCoordinator, -) ([]string, error) { - return cnsm.GetNextConsensusGroupCalled(randomSource, vgs) -} - -// IsConsensusDataSet - -func (cnsm *ConsensusStateMock) IsConsensusDataSet() bool { - return cnsm.IsConsensusDataSetCalled() -} - -// IsConsensusDataEqual - -func (cnsm *ConsensusStateMock) IsConsensusDataEqual(data []byte) bool { - return cnsm.IsConsensusDataEqualCalled(data) -} - -// IsJobDone - -func (cnsm *ConsensusStateMock) IsJobDone(node string, currentSubroundId int) bool { - return cnsm.IsJobDoneCalled(node, currentSubroundId) -} - -// IsSelfJobDone - -func (cnsm *ConsensusStateMock) IsSelfJobDone(currentSubroundId int) bool { - return cnsm.IsSelfJobDoneCalled(currentSubroundId) -} - -// IsCurrentSubroundFinished - -func (cnsm *ConsensusStateMock) IsCurrentSubroundFinished(currentSubroundId int) bool { - return cnsm.IsCurrentSubroundFinishedCalled(currentSubroundId) -} - -// IsNodeSelf - -func (cnsm *ConsensusStateMock) IsNodeSelf(node string) bool { - return cnsm.IsNodeSelfCalled(node) -} - -// IsBlockBodyAlreadyReceived - -func (cnsm *ConsensusStateMock) IsBlockBodyAlreadyReceived() bool { - return cnsm.IsBlockBodyAlreadyReceivedCalled() -} - -// IsHeaderAlreadyReceived - -func (cnsm *ConsensusStateMock) IsHeaderAlreadyReceived() bool { - return cnsm.IsHeaderAlreadyReceivedCalled() -} - -// CanDoSubroundJob - -func (cnsm *ConsensusStateMock) CanDoSubroundJob(currentSubroundId int) bool { - return cnsm.CanDoSubroundJobCalled(currentSubroundId) -} - -// CanProcessReceivedMessage - -func (cnsm *ConsensusStateMock) CanProcessReceivedMessage( - cnsDta consensus.Message, - currentRoundIndex int32, - currentSubroundId int, -) bool { - return cnsm.CanProcessReceivedMessageCalled(cnsDta, currentRoundIndex, currentSubroundId) -} - -// GenerateBitmap - -func (cnsm *ConsensusStateMock) GenerateBitmap(subroundId int) []byte { - return cnsm.GenerateBitmapCalled(subroundId) -} - -// ProcessingBlock - -func (cnsm *ConsensusStateMock) ProcessingBlock() bool { - return cnsm.ProcessingBlockCalled() -} - -// SetProcessingBlock - -func (cnsm *ConsensusStateMock) SetProcessingBlock(processingBlock bool) { - cnsm.SetProcessingBlockCalled(processingBlock) -} - -// ConsensusGroupSize - -func (cnsm *ConsensusStateMock) ConsensusGroupSize() int { - return cnsm.ConsensusGroupSizeCalled() -} - -// SetThreshold - 
-func (cnsm *ConsensusStateMock) SetThreshold(subroundId int, threshold int) { - cnsm.SetThresholdCalled(subroundId, threshold) -} diff --git a/consensus/mock/contextMock.go b/consensus/mock/contextMock.go new file mode 100644 index 00000000000..0cdab606821 --- /dev/null +++ b/consensus/mock/contextMock.go @@ -0,0 +1,45 @@ +package mock + +import ( + "time" +) + +// ContextMock - +type ContextMock struct { + DoneFunc func() <-chan struct{} + DeadlineFunc func() (time.Time, bool) + ErrFunc func() error + ValueFunc func(key interface{}) interface{} +} + +// Done - +func (c *ContextMock) Done() <-chan struct{} { + if c.DoneFunc != nil { + return c.DoneFunc() + } + return nil +} + +// Deadline - +func (c *ContextMock) Deadline() (time.Time, bool) { + if c.DeadlineFunc != nil { + return c.DeadlineFunc() + } + return time.Time{}, false +} + +// Err - +func (c *ContextMock) Err() error { + if c.ErrFunc != nil { + return c.ErrFunc() + } + return nil +} + +// Value - +func (c *ContextMock) Value(key interface{}) interface{} { + if c.ValueFunc != nil { + return c.ValueFunc(key) + } + return nil +} diff --git a/consensus/mock/forkDetectorMock.go b/consensus/mock/forkDetectorMock.go deleted file mode 100644 index 6c1a4f70d5e..00000000000 --- a/consensus/mock/forkDetectorMock.go +++ /dev/null @@ -1,93 +0,0 @@ -package mock - -import ( - "github.com/multiversx/mx-chain-core-go/data" - "github.com/multiversx/mx-chain-go/process" -) - -// ForkDetectorMock - -type ForkDetectorMock struct { - AddHeaderCalled func(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error - RemoveHeaderCalled func(nonce uint64, hash []byte) - CheckForkCalled func() *process.ForkInfo - GetHighestFinalBlockNonceCalled func() uint64 - GetHighestFinalBlockHashCalled func() []byte - ProbableHighestNonceCalled func() uint64 - ResetForkCalled func() - GetNotarizedHeaderHashCalled func(nonce uint64) []byte - SetRollBackNonceCalled func(nonce uint64) - RestoreToGenesisCalled func() - ResetProbableHighestNonceCalled func() - SetFinalToLastCheckpointCalled func() -} - -// RestoreToGenesis - -func (fdm *ForkDetectorMock) RestoreToGenesis() { - fdm.RestoreToGenesisCalled() -} - -// AddHeader - -func (fdm *ForkDetectorMock) AddHeader(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error { - return fdm.AddHeaderCalled(header, hash, state, selfNotarizedHeaders, selfNotarizedHeadersHashes) -} - -// RemoveHeader - -func (fdm *ForkDetectorMock) RemoveHeader(nonce uint64, hash []byte) { - fdm.RemoveHeaderCalled(nonce, hash) -} - -// CheckFork - -func (fdm *ForkDetectorMock) CheckFork() *process.ForkInfo { - return fdm.CheckForkCalled() -} - -// GetHighestFinalBlockNonce - -func (fdm *ForkDetectorMock) GetHighestFinalBlockNonce() uint64 { - return fdm.GetHighestFinalBlockNonceCalled() -} - -// GetHighestFinalBlockHash - -func (fdm *ForkDetectorMock) GetHighestFinalBlockHash() []byte { - return fdm.GetHighestFinalBlockHashCalled() -} - -// ProbableHighestNonce - -func (fdm *ForkDetectorMock) ProbableHighestNonce() uint64 { - return fdm.ProbableHighestNonceCalled() -} - -// SetRollBackNonce - -func (fdm *ForkDetectorMock) SetRollBackNonce(nonce uint64) { - if fdm.SetRollBackNonceCalled != nil { - fdm.SetRollBackNonceCalled(nonce) - } -} - -// ResetFork - -func (fdm *ForkDetectorMock) ResetFork() { - 
fdm.ResetForkCalled() -} - -// GetNotarizedHeaderHash - -func (fdm *ForkDetectorMock) GetNotarizedHeaderHash(nonce uint64) []byte { - return fdm.GetNotarizedHeaderHashCalled(nonce) -} - -// ResetProbableHighestNonce - -func (fdm *ForkDetectorMock) ResetProbableHighestNonce() { - if fdm.ResetProbableHighestNonceCalled != nil { - fdm.ResetProbableHighestNonceCalled() - } -} - -// SetFinalToLastCheckpoint - -func (fdm *ForkDetectorMock) SetFinalToLastCheckpoint() { - if fdm.SetFinalToLastCheckpointCalled != nil { - fdm.SetFinalToLastCheckpointCalled() - } -} - -// IsInterfaceNil returns true if there is no value under the interface -func (fdm *ForkDetectorMock) IsInterfaceNil() bool { - return fdm == nil -} diff --git a/consensus/mock/headerIntegrityVerifierStub.go b/consensus/mock/headerIntegrityVerifierStub.go deleted file mode 100644 index 3d793b89924..00000000000 --- a/consensus/mock/headerIntegrityVerifierStub.go +++ /dev/null @@ -1,32 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderIntegrityVerifierStub - -type HeaderIntegrityVerifierStub struct { - VerifyCalled func(header data.HeaderHandler) error - GetVersionCalled func(epoch uint32) string -} - -// Verify - -func (h *HeaderIntegrityVerifierStub) Verify(header data.HeaderHandler) error { - if h.VerifyCalled != nil { - return h.VerifyCalled(header) - } - - return nil -} - -// GetVersion - -func (h *HeaderIntegrityVerifierStub) GetVersion(epoch uint32) string { - if h.GetVersionCalled != nil { - return h.GetVersionCalled(epoch) - } - - return "version" -} - -// IsInterfaceNil - -func (h *HeaderIntegrityVerifierStub) IsInterfaceNil() bool { - return h == nil -} diff --git a/consensus/mock/headerSigVerifierStub.go b/consensus/mock/headerSigVerifierStub.go deleted file mode 100644 index b75b5615a12..00000000000 --- a/consensus/mock/headerSigVerifierStub.go +++ /dev/null @@ -1,52 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderSigVerifierStub - -type HeaderSigVerifierStub struct { - VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error - VerifySignatureCalled func(header data.HeaderHandler) error - VerifyRandSeedCalled func(header data.HeaderHandler) error - VerifyLeaderSignatureCalled func(header data.HeaderHandler) error -} - -// VerifyRandSeed - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeed(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedCalled != nil { - return hsvm.VerifyRandSeedCalled(header) - } - - return nil -} - -// VerifyRandSeedAndLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedAndLeaderSignatureCalled != nil { - return hsvm.VerifyRandSeedAndLeaderSignatureCalled(header) - } - - return nil -} - -// VerifySignature - -func (hsvm *HeaderSigVerifierStub) VerifySignature(header data.HeaderHandler) error { - if hsvm.VerifySignatureCalled != nil { - return hsvm.VerifySignatureCalled(header) - } - - return nil -} - -// VerifyLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyLeaderSignatureCalled != nil { - return hsvm.VerifyLeaderSignatureCalled(header) - } - - return nil -} - -// IsInterfaceNil - -func (hsvm *HeaderSigVerifierStub) IsInterfaceNil() bool { - return hsvm == nil -} diff --git a/consensus/mock/headersCacherStub.go b/consensus/mock/headersCacherStub.go deleted file mode 100644 index 
bc458a8235f..00000000000 --- a/consensus/mock/headersCacherStub.go +++ /dev/null @@ -1,105 +0,0 @@ -package mock - -import ( - "errors" - - "github.com/multiversx/mx-chain-core-go/data" -) - -// HeadersCacherStub - -type HeadersCacherStub struct { - AddCalled func(headerHash []byte, header data.HeaderHandler) - RemoveHeaderByHashCalled func(headerHash []byte) - RemoveHeaderByNonceAndShardIdCalled func(hdrNonce uint64, shardId uint32) - GetHeaderByNonceAndShardIdCalled func(hdrNonce uint64, shardId uint32) ([]data.HeaderHandler, [][]byte, error) - GetHeaderByHashCalled func(hash []byte) (data.HeaderHandler, error) - ClearCalled func() - RegisterHandlerCalled func(handler func(header data.HeaderHandler, shardHeaderHash []byte)) - NoncesCalled func(shardId uint32) []uint64 - LenCalled func() int - MaxSizeCalled func() int - GetNumHeadersCalled func(shardId uint32) int -} - -// AddHeader - -func (hcs *HeadersCacherStub) AddHeader(headerHash []byte, header data.HeaderHandler) { - if hcs.AddCalled != nil { - hcs.AddCalled(headerHash, header) - } -} - -// RemoveHeaderByHash - -func (hcs *HeadersCacherStub) RemoveHeaderByHash(headerHash []byte) { - if hcs.RemoveHeaderByHashCalled != nil { - hcs.RemoveHeaderByHashCalled(headerHash) - } -} - -// RemoveHeaderByNonceAndShardId - -func (hcs *HeadersCacherStub) RemoveHeaderByNonceAndShardId(hdrNonce uint64, shardId uint32) { - if hcs.RemoveHeaderByNonceAndShardIdCalled != nil { - hcs.RemoveHeaderByNonceAndShardIdCalled(hdrNonce, shardId) - } -} - -// GetHeadersByNonceAndShardId - -func (hcs *HeadersCacherStub) GetHeadersByNonceAndShardId(hdrNonce uint64, shardId uint32) ([]data.HeaderHandler, [][]byte, error) { - if hcs.GetHeaderByNonceAndShardIdCalled != nil { - return hcs.GetHeaderByNonceAndShardIdCalled(hdrNonce, shardId) - } - return nil, nil, errors.New("err") -} - -// GetHeaderByHash - -func (hcs *HeadersCacherStub) GetHeaderByHash(hash []byte) (data.HeaderHandler, error) { - if hcs.GetHeaderByHashCalled != nil { - return hcs.GetHeaderByHashCalled(hash) - } - return nil, nil -} - -// Clear - -func (hcs *HeadersCacherStub) Clear() { - if hcs.ClearCalled != nil { - hcs.ClearCalled() - } -} - -// RegisterHandler - -func (hcs *HeadersCacherStub) RegisterHandler(handler func(header data.HeaderHandler, shardHeaderHash []byte)) { - if hcs.RegisterHandlerCalled != nil { - hcs.RegisterHandlerCalled(handler) - } -} - -// Nonces - -func (hcs *HeadersCacherStub) Nonces(shardId uint32) []uint64 { - if hcs.NoncesCalled != nil { - return hcs.NoncesCalled(shardId) - } - return nil -} - -// Len - -func (hcs *HeadersCacherStub) Len() int { - return 0 -} - -// MaxSize - -func (hcs *HeadersCacherStub) MaxSize() int { - return 100 -} - -// IsInterfaceNil - -func (hcs *HeadersCacherStub) IsInterfaceNil() bool { - return hcs == nil -} - -// GetNumHeaders - -func (hcs *HeadersCacherStub) GetNumHeaders(shardId uint32) int { - if hcs.GetNumHeadersCalled != nil { - return hcs.GetNumHeadersCalled(shardId) - } - - return 0 -} diff --git a/consensus/mock/watchdogMock.go b/consensus/mock/watchdogMock.go index 15a153f50a0..1c026b4e8c4 100644 --- a/consensus/mock/watchdogMock.go +++ b/consensus/mock/watchdogMock.go @@ -6,10 +6,15 @@ import ( // WatchdogMock - type WatchdogMock struct { + SetCalled func(callback func(alarmID string), duration time.Duration, alarmID string) + StopCalled func(alarmID string) } // Set - func (w *WatchdogMock) Set(callback func(alarmID string), duration time.Duration, alarmID string) { + if w.SetCalled != nil { + w.SetCalled(callback, 
duration, alarmID) + } } // SetDefault - @@ -18,6 +23,9 @@ func (w *WatchdogMock) SetDefault(duration time.Duration, alarmID string) { // Stop - func (w *WatchdogMock) Stop(alarmID string) { + if w.StopCalled != nil { + w.StopCalled(alarmID) + } } // Reset - diff --git a/consensus/round/round.go b/consensus/round/round.go index 3b0ea17ce84..e4ed2e07d1d 100644 --- a/consensus/round/round.go +++ b/consensus/round/round.go @@ -6,6 +6,7 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/ntp" ) @@ -100,6 +101,14 @@ func (rnd *round) RemainingTime(startTime time.Time, maxTime time.Duration) time return remainingTime } +// RevertOneRound reverts the round index and time stamp by one round, used in case of a transition to new consensus +func (rnd *round) RevertOneRound() { + rnd.Lock() + rnd.index-- + rnd.timeStamp = rnd.timeStamp.Add(-rnd.timeDuration) + rnd.Unlock() +} + // IsInterfaceNil returns true if there is no value under the interface func (rnd *round) IsInterfaceNil() bool { return rnd == nil diff --git a/consensus/round/round_test.go b/consensus/round/round_test.go index ede509d7176..e0ece36b572 100644 --- a/consensus/round/round_test.go +++ b/consensus/round/round_test.go @@ -5,8 +5,11 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus/round" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/stretchr/testify/assert" ) @@ -28,7 +31,7 @@ func TestRound_NewRoundShouldWork(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, err := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) @@ -41,7 +44,7 @@ func TestRound_UpdateRoundShouldNotChangeAnything(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) oldIndex := rnd.Index() @@ -61,7 +64,7 @@ func TestRound_UpdateRoundShouldAdvanceOneRound(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) oldIndex := rnd.Index() @@ -76,7 +79,7 @@ func TestRound_IndexShouldReturnFirstIndex(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) rnd.UpdateRound(genesisTime, genesisTime.Add(roundTimeDuration/2)) @@ -90,7 +93,7 @@ func TestRound_TimeStampShouldReturnTimeStampOfTheNextRound(t *testing.T) { genesisTime := time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) rnd.UpdateRound(genesisTime, genesisTime.Add(roundTimeDuration+roundTimeDuration/2)) @@ -104,7 +107,7 @@ func TestRound_TimeDurationShouldReturnTheDurationOfOneRound(t *testing.T) { genesisTime := 
time.Now() - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, 0) timeDuration := rnd.TimeDuration() @@ -117,7 +120,7 @@ func TestRound_RemainingTimeInCurrentRoundShouldReturnPositiveValue(t *testing.T genesisTime := time.Unix(0, 0) - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} timeElapsed := int64(roundTimeDuration - 1) @@ -138,7 +141,7 @@ func TestRound_RemainingTimeInCurrentRoundShouldReturnNegativeValue(t *testing.T genesisTime := time.Unix(0, 0) - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} timeElapsed := int64(roundTimeDuration + 1) @@ -153,3 +156,38 @@ func TestRound_RemainingTimeInCurrentRoundShouldReturnNegativeValue(t *testing.T assert.Equal(t, time.Duration(int64(rnd.TimeDuration())-timeElapsed), remainingTime) assert.True(t, remainingTime < 0) } + +func TestRound_RevertOneRound(t *testing.T) { + t.Parallel() + + genesisTime := time.Now() + + syncTimerMock := &consensusMocks.SyncTimerMock{} + + startRound := int64(10) + rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, startRound) + index := rnd.Index() + require.Equal(t, startRound, index) + + rnd.RevertOneRound() + index = rnd.Index() + require.Equal(t, startRound-1, index) +} + +func TestRound_BeforeGenesis(t *testing.T) { + t.Parallel() + + genesisTime := time.Now() + + syncTimerMock := &consensusMocks.SyncTimerMock{} + + startRound := int64(-1) + rnd, _ := round.NewRound(genesisTime, genesisTime, roundTimeDuration, syncTimerMock, startRound) + require.True(t, rnd.BeforeGenesis()) + + time.Sleep(roundTimeDuration * 2) + currentTime := time.Now() + + rnd.UpdateRound(genesisTime, currentTime) + require.False(t, rnd.BeforeGenesis()) +} diff --git a/consensus/spos/bls/blsWorker.go b/consensus/spos/bls/blsWorker.go index 456d4e8b1d8..ef0cdfc35ee 100644 --- a/consensus/spos/bls/blsWorker.go +++ b/consensus/spos/bls/blsWorker.go @@ -5,7 +5,7 @@ import ( "github.com/multiversx/mx-chain-go/consensus/spos" ) -// peerMaxMessagesPerSec defines how many messages can be propagated by a pid in a round. The value was chosen by +// PeerMaxMessagesPerSec defines how many messages can be propagated by a pid in a round. The value was chosen by // following the next premises: // 1. a leader can propagate as maximum as 3 messages per round: proposed header block + proposed body + final info; // 2. 
due to the fact that a delayed signature of the proposer (from previous round) can be received in the current round @@ -16,15 +16,15 @@ import ( // // Validators only send one signature message in a round, treating the edge case of a delayed message, will need at most // 2 messages per round (which is ok as it is below the set value of 5) -const peerMaxMessagesPerSec = uint32(6) +const PeerMaxMessagesPerSec = uint32(6) -// defaultMaxNumOfMessageTypeAccepted represents the maximum number of the same message type accepted in one round to be +// DefaultMaxNumOfMessageTypeAccepted represents the maximum number of the same message type accepted in one round to be // received from the same public key for the default message types -const defaultMaxNumOfMessageTypeAccepted = uint32(1) +const DefaultMaxNumOfMessageTypeAccepted = uint32(1) -// maxNumOfMessageTypeSignatureAccepted represents the maximum number of the signature message type accepted in one round to be +// MaxNumOfMessageTypeSignatureAccepted represents the maximum number of the signature message type accepted in one round to be // received from the same public key -const maxNumOfMessageTypeSignatureAccepted = uint32(2) +const MaxNumOfMessageTypeSignatureAccepted = uint32(2) // worker defines the data needed by spos to communicate between nodes which are in the validators group type worker struct { @@ -52,17 +52,17 @@ func (wrk *worker) InitReceivedMessages() map[consensus.MessageType][]*consensus // GetMaxMessagesInARoundPerPeer returns the maximum number of messages a peer can send per round for BLS func (wrk *worker) GetMaxMessagesInARoundPerPeer() uint32 { - return peerMaxMessagesPerSec + return PeerMaxMessagesPerSec } // GetStringValue gets the name of the messageType func (wrk *worker) GetStringValue(messageType consensus.MessageType) string { - return getStringValue(messageType) + return GetStringValue(messageType) } // GetSubroundName gets the subround name for the subround id provided func (wrk *worker) GetSubroundName(subroundId int) string { - return getSubroundName(subroundId) + return GetSubroundName(subroundId) } // IsMessageWithBlockBodyAndHeader returns if the current messageType is about block body and header @@ -148,13 +148,18 @@ func (wrk *worker) CanProceed(consensusState *spos.ConsensusState, msgType conse return false } +// GetMessageTypeBlockHeader returns the message type for block header +func (wrk *worker) GetMessageTypeBlockHeader() consensus.MessageType { + return MtBlockHeader +} + // GetMaxNumOfMessageTypeAccepted returns the maximum number of accepted consensus message types per round, per public key func (wrk *worker) GetMaxNumOfMessageTypeAccepted(msgType consensus.MessageType) uint32 { if msgType == MtSignature { - return maxNumOfMessageTypeSignatureAccepted + return MaxNumOfMessageTypeSignatureAccepted } - return defaultMaxNumOfMessageTypeAccepted + return DefaultMaxNumOfMessageTypeAccepted } // IsInterfaceNil returns true if there is no value under the interface diff --git a/consensus/spos/bls/blsWorker_test.go b/consensus/spos/bls/blsWorker_test.go index 6786b96cde8..8d39b02e5f1 100644 --- a/consensus/spos/bls/blsWorker_test.go +++ b/consensus/spos/bls/blsWorker_test.go @@ -4,68 +4,14 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" 
"github.com/multiversx/mx-chain-go/consensus/spos/bls" - "github.com/multiversx/mx-chain-go/testscommon" - "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" ) -func createEligibleList(size int) []string { - eligibleList := make([]string, 0) - for i := 0; i < size; i++ { - eligibleList = append(eligibleList, string([]byte{byte(i + 65)})) - } - return eligibleList -} - -func initConsensusState() *spos.ConsensusState { - return initConsensusStateWithKeysHandler(&testscommon.KeysHandlerStub{}) -} - -func initConsensusStateWithKeysHandler(keysHandler consensus.KeysHandler) *spos.ConsensusState { - consensusGroupSize := 9 - eligibleList := createEligibleList(consensusGroupSize) - - eligibleNodesPubKeys := make(map[string]struct{}) - for _, key := range eligibleList { - eligibleNodesPubKeys[key] = struct{}{} - } - - indexLeader := 1 - rcns, _ := spos.NewRoundConsensus( - eligibleNodesPubKeys, - consensusGroupSize, - eligibleList[indexLeader], - keysHandler, - ) - - rcns.SetConsensusGroup(eligibleList) - rcns.ResetRoundState() - - pBFTThreshold := consensusGroupSize*2/3 + 1 - pBFTFallbackThreshold := consensusGroupSize*1/2 + 1 - - rthr := spos.NewRoundThreshold() - rthr.SetThreshold(1, 1) - rthr.SetThreshold(2, pBFTThreshold) - rthr.SetFallbackThreshold(1, 1) - rthr.SetFallbackThreshold(2, pBFTFallbackThreshold) - - rstatus := spos.NewRoundStatus() - rstatus.ResetRoundStatus() - - cns := spos.NewConsensusState( - rcns, - rthr, - rstatus, - ) - - cns.Data = []byte("X") - cns.RoundIndex = 0 - return cns -} - func TestWorker_NewConsensusServiceShouldWork(t *testing.T) { t.Parallel() @@ -121,7 +67,7 @@ func TestWorker_CanProceedWithSrStartRoundFinishedForMtBlockBodyAndHeaderShouldW blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockBodyAndHeader) @@ -133,7 +79,7 @@ func TestWorker_CanProceedWithSrStartRoundNotFinishedForMtBlockBodyAndHeaderShou blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsNotFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockBodyAndHeader) @@ -145,7 +91,7 @@ func TestWorker_CanProceedWithSrStartRoundFinishedForMtBlockBodyShouldWork(t *te blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockBody) @@ -157,7 +103,7 @@ func TestWorker_CanProceedWithSrStartRoundNotFinishedForMtBlockBodyShouldNotWork blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsNotFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockBody) @@ -169,7 +115,7 @@ func TestWorker_CanProceedWithSrStartRoundFinishedForMtBlockHeaderShouldWork(t * blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsFinished) canProceed := 
blsService.CanProceed(consensusState, bls.MtBlockHeader) @@ -181,7 +127,7 @@ func TestWorker_CanProceedWithSrStartRoundNotFinishedForMtBlockHeaderShouldNotWo blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrStartRound, spos.SsNotFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockHeader) @@ -193,7 +139,7 @@ func TestWorker_CanProceedWithSrBlockFinishedForMtBlockHeaderShouldWork(t *testi blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrBlock, spos.SsFinished) canProceed := blsService.CanProceed(consensusState, bls.MtSignature) @@ -205,7 +151,7 @@ func TestWorker_CanProceedWithSrBlockRoundNotFinishedForMtBlockHeaderShouldNotWo blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrBlock, spos.SsNotFinished) canProceed := blsService.CanProceed(consensusState, bls.MtSignature) @@ -217,7 +163,7 @@ func TestWorker_CanProceedWithSrSignatureFinishedForMtBlockHeaderFinalInfoShould blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrSignature, spos.SsFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockHeaderFinalInfo) @@ -229,7 +175,7 @@ func TestWorker_CanProceedWithSrSignatureRoundNotFinishedForMtBlockHeaderFinalIn blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() consensusState.SetStatus(bls.SrSignature, spos.SsNotFinished) canProceed := blsService.CanProceed(consensusState, bls.MtBlockHeaderFinalInfo) @@ -240,7 +186,7 @@ func TestWorker_CanProceedWitUnkownMessageTypeShouldNotWork(t *testing.T) { t.Parallel() blsService, _ := bls.NewConsensusService() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() canProceed := blsService.CanProceed(consensusState, -1) assert.False(t, canProceed) diff --git a/consensus/spos/bls/constants.go b/consensus/spos/bls/constants.go index 166abe70b65..88667da3003 100644 --- a/consensus/spos/bls/constants.go +++ b/consensus/spos/bls/constants.go @@ -2,11 +2,8 @@ package bls import ( "github.com/multiversx/mx-chain-go/consensus" - logger "github.com/multiversx/mx-chain-logger-go" ) -var log = logger.GetOrCreate("consensus/spos/bls") - const ( // SrStartRound defines ID of Subround "Start round" SrStartRound = iota @@ -36,36 +33,6 @@ const ( MtInvalidSigners ) -// waitingAllSigsMaxTimeThreshold specifies the max allocated time for waiting all signatures from the total time of the subround signature -const waitingAllSigsMaxTimeThreshold = 0.5 - -// processingThresholdPercent specifies the max allocated time for processing the block as a percentage of the total time of the round -const processingThresholdPercent = 85 - -// srStartStartTime specifies the start time, from the total time of the round, of Subround Start -const srStartStartTime = 0.0 - -// srEndStartTime specifies the end time, from the total time of the round, of Subround Start -const srStartEndTime = 0.05 - -// srBlockStartTime specifies the start time, from the total time of the round, of Subround Block -const srBlockStartTime = 0.05 - -// 
srBlockEndTime specifies the end time, from the total time of the round, of Subround Block -const srBlockEndTime = 0.25 - -// srSignatureStartTime specifies the start time, from the total time of the round, of Subround Signature -const srSignatureStartTime = 0.25 - -// srSignatureEndTime specifies the end time, from the total time of the round, of Subround Signature -const srSignatureEndTime = 0.85 - -// srEndStartTime specifies the start time, from the total time of the round, of Subround End -const srEndStartTime = 0.85 - -// srEndEndTime specifies the end time, from the total time of the round, of Subround End -const srEndEndTime = 0.95 - const ( // BlockBodyAndHeaderStringValue represents the string to be used to identify a block body and a block header BlockBodyAndHeaderStringValue = "(BLOCK_BODY_AND_HEADER)" @@ -89,7 +56,8 @@ const ( BlockDefaultStringValue = "Undefined message type" ) -func getStringValue(msgType consensus.MessageType) string { +// GetStringValue returns the string value of a given MessageType +func GetStringValue(msgType consensus.MessageType) string { switch msgType { case MtBlockBodyAndHeader: return BlockBodyAndHeaderStringValue @@ -108,8 +76,8 @@ func getStringValue(msgType consensus.MessageType) string { } } -// getSubroundName returns the name of each Subround from a given Subround ID -func getSubroundName(subroundId int) string { +// GetSubroundName returns the name of each Subround from a given Subround ID +func GetSubroundName(subroundId int) string { switch subroundId { case SrStartRound: return "(START_ROUND)" diff --git a/consensus/spos/bls/errors.go b/consensus/spos/bls/errors.go deleted file mode 100644 index b840f9e2c85..00000000000 --- a/consensus/spos/bls/errors.go +++ /dev/null @@ -1,6 +0,0 @@ -package bls - -import "errors" - -// ErrNilSentSignatureTracker defines the error for setting a nil SentSignatureTracker -var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") diff --git a/consensus/spos/bls/proxy/errors.go b/consensus/spos/bls/proxy/errors.go new file mode 100644 index 00000000000..4036ecf1c63 --- /dev/null +++ b/consensus/spos/bls/proxy/errors.go @@ -0,0 +1,38 @@ +package proxy + +import ( + "errors" +) + +// ErrNilChronologyHandler is the error returned when the chronology handler is nil +var ErrNilChronologyHandler = errors.New("nil chronology handler") + +// ErrNilConsensusCoreHandler is the error returned when the consensus core handler is nil +var ErrNilConsensusCoreHandler = errors.New("nil consensus core handler") + +// ErrNilConsensusState is the error returned when the consensus state is nil +var ErrNilConsensusState = errors.New("nil consensus state") + +// ErrNilWorker is the error returned when the worker is nil +var ErrNilWorker = errors.New("nil worker") + +// ErrNilSignatureThrottler is the error returned when the signature throttler is nil +var ErrNilSignatureThrottler = errors.New("nil signature throttler") + +// ErrNilAppStatusHandler is the error returned when the app status handler is nil +var ErrNilAppStatusHandler = errors.New("nil app status handler") + +// ErrNilOutportHandler is the error returned when the outport handler is nil +var ErrNilOutportHandler = errors.New("nil outport handler") + +// ErrNilSentSignatureTracker is the error returned when the sent signature tracker is nil +var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") + +// ErrNilChainID is the error returned when the chain ID is nil +var ErrNilChainID = errors.New("nil chain ID") + +// ErrNilCurrentPid is 
the error returned when the current PID is nil +var ErrNilCurrentPid = errors.New("nil current PID") + +// ErrNilEnableEpochsHandler is the error returned when the enable epochs handler is nil +var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") diff --git a/consensus/spos/bls/proxy/subroundsHandler.go b/consensus/spos/bls/proxy/subroundsHandler.go new file mode 100644 index 00000000000..52b15c4ec9b --- /dev/null +++ b/consensus/spos/bls/proxy/subroundsHandler.go @@ -0,0 +1,204 @@ +package proxy + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + logger "github.com/multiversx/mx-chain-logger-go" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + "github.com/multiversx/mx-chain-go/factory" + "github.com/multiversx/mx-chain-go/outport" +) + +var log = logger.GetOrCreate("consensus/spos/bls/proxy") + +// SubroundsHandlerArgs struct contains the needed data for the SubroundsHandler +type SubroundsHandlerArgs struct { + Chronology consensus.ChronologyHandler + ConsensusCoreHandler spos.ConsensusCoreHandler + ConsensusState spos.ConsensusStateHandler + Worker factory.ConsensusWorker + SignatureThrottler core.Throttler + AppStatusHandler core.AppStatusHandler + OutportHandler outport.OutportHandler + SentSignatureTracker spos.SentSignaturesTracker + EnableEpochsHandler core.EnableEpochsHandler + ChainID []byte + CurrentPid core.PeerID +} + +// subroundsFactory defines the methods needed to generate the subrounds +type subroundsFactory interface { + GenerateSubrounds(epoch uint32) error + SetOutportHandler(driver outport.OutportHandler) + IsInterfaceNil() bool +} + +type consensusStateMachineType int + +// SubroundsHandler struct contains the needed data for the SubroundsHandler +type SubroundsHandler struct { + chronology consensus.ChronologyHandler + consensusCoreHandler spos.ConsensusCoreHandler + consensusState spos.ConsensusStateHandler + worker factory.ConsensusWorker + signatureThrottler core.Throttler + appStatusHandler core.AppStatusHandler + outportHandler outport.OutportHandler + sentSignatureTracker spos.SentSignaturesTracker + enableEpochsHandler core.EnableEpochsHandler + chainID []byte + currentPid core.PeerID + currentConsensusType consensusStateMachineType +} + +// EpochConfirmed is called when the epoch is confirmed (this is registered as callback) +func (s *SubroundsHandler) EpochConfirmed(epoch uint32, _ uint64) { + err := s.initSubroundsForEpoch(epoch) + if err != nil { + log.Error("SubroundsHandler.EpochConfirmed: cannot initialize subrounds", "error", err) + } +} + +const ( + consensusNone consensusStateMachineType = iota + consensusV1 + consensusV2 +) + +// NewSubroundsHandler creates a new SubroundsHandler object +func NewSubroundsHandler(args *SubroundsHandlerArgs) (*SubroundsHandler, error) { + err := checkArgs(args) + if err != nil { + return nil, err + } + + subroundHandler := &SubroundsHandler{ + chronology: args.Chronology, + consensusCoreHandler: args.ConsensusCoreHandler, + consensusState: args.ConsensusState, + worker: args.Worker, + signatureThrottler: args.SignatureThrottler, + 
appStatusHandler: args.AppStatusHandler, + outportHandler: args.OutportHandler, + sentSignatureTracker: args.SentSignatureTracker, + enableEpochsHandler: args.EnableEpochsHandler, + chainID: args.ChainID, + currentPid: args.CurrentPid, + currentConsensusType: consensusNone, + } + + subroundHandler.consensusCoreHandler.EpochNotifier().RegisterNotifyHandler(subroundHandler) + + return subroundHandler, nil +} + +func checkArgs(args *SubroundsHandlerArgs) error { + if check.IfNil(args.Chronology) { + return ErrNilChronologyHandler + } + if check.IfNil(args.ConsensusCoreHandler) { + return ErrNilConsensusCoreHandler + } + if check.IfNil(args.ConsensusState) { + return ErrNilConsensusState + } + if check.IfNil(args.Worker) { + return ErrNilWorker + } + if check.IfNil(args.SignatureThrottler) { + return ErrNilSignatureThrottler + } + if check.IfNil(args.AppStatusHandler) { + return ErrNilAppStatusHandler + } + if check.IfNil(args.OutportHandler) { + return ErrNilOutportHandler + } + if check.IfNil(args.SentSignatureTracker) { + return ErrNilSentSignatureTracker + } + if check.IfNil(args.EnableEpochsHandler) { + return ErrNilEnableEpochsHandler + } + if args.ChainID == nil { + return ErrNilChainID + } + if len(args.CurrentPid) == 0 { + return ErrNilCurrentPid + } + // outport handler can be nil if not configured so no need to check it + + return nil +} + +// Start starts the sub-rounds handler +func (s *SubroundsHandler) Start(epoch uint32) error { + return s.initSubroundsForEpoch(epoch) +} + +func (s *SubroundsHandler) initSubroundsForEpoch(epoch uint32) error { + var err error + var fct subroundsFactory + if s.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + if s.currentConsensusType == consensusV2 { + return nil + } + + s.currentConsensusType = consensusV2 + fct, err = v2.NewSubroundsFactory( + s.consensusCoreHandler, + s.consensusState, + s.worker, + s.chainID, + s.currentPid, + s.appStatusHandler, + s.sentSignatureTracker, + s.signatureThrottler, + s.outportHandler, + ) + } else { + if s.currentConsensusType == consensusV1 { + return nil + } + + s.currentConsensusType = consensusV1 + fct, err = v1.NewSubroundsFactory( + s.consensusCoreHandler, + s.consensusState, + s.worker, + s.chainID, + s.currentPid, + s.appStatusHandler, + s.sentSignatureTracker, + s.outportHandler, + ) + } + if err != nil { + return err + } + + err = s.chronology.Close() + if err != nil { + log.Warn("SubroundsHandler.initSubroundsForEpoch: cannot close the chronology", "error", err) + } + + err = fct.GenerateSubrounds(epoch) + if err != nil { + return err + } + + log.Debug("SubroundsHandler.initSubroundsForEpoch: reset consensus round state") + s.worker.ResetConsensusRoundState() + s.chronology.StartRounds() + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (s *SubroundsHandler) IsInterfaceNil() bool { + return s == nil +} diff --git a/consensus/spos/bls/proxy/subroundsHandler_test.go b/consensus/spos/bls/proxy/subroundsHandler_test.go new file mode 100644 index 00000000000..8367968260e --- /dev/null +++ b/consensus/spos/bls/proxy/subroundsHandler_test.go @@ -0,0 +1,475 @@ +package proxy + +import ( + "sync/atomic" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/require" + + mock2 "github.com/multiversx/mx-chain-go/consensus/mock" + 
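A minimal wiring sketch for the new proxy handler (illustrative only, not part of the diff): it assumes the caller already owns the dependencies listed in SubroundsHandlerArgs and only shows the intended call order, namely construct the handler, then Start it for the current epoch; later epoch changes reach it through EpochConfirmed via the epoch notifier registered in the constructor. The function name startConsensusSubrounds is hypothetical.

func startConsensusSubrounds(deps SubroundsHandlerArgs, currentEpoch uint32) error {
	// NewSubroundsHandler validates the dependencies and registers the handler
	// with the consensus core's EpochNotifier.
	handler, err := NewSubroundsHandler(&deps)
	if err != nil {
		return err
	}

	// Start generates the v1 or v2 subrounds for the given epoch (depending on
	// the Andromeda flag) and starts the chronology rounds.
	return handler.Start(currentEpoch)
}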
"github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" + "github.com/multiversx/mx-chain-go/testscommon/common" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" + mock "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" + outportStub "github.com/multiversx/mx-chain-go/testscommon/outport" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +func getDefaultArgumentsSubroundHandler() (*SubroundsHandlerArgs, *spos.ConsensusCore) { + x := make(chan bool) + chronology := &consensus.ChronologyHandlerMock{} + epochsEnable := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} + epochStartNotifier := &mock.EpochStartNotifierStub{} + consensusState := &consensus.ConsensusStateMock{} + epochNotifier := &epochNotifierMock.EpochNotifierStub{} + worker := &consensus.SposWorkerMock{ + RemoveAllReceivedMessagesCallsCalled: func() {}, + GetConsensusStateChangedChannelsCalled: func() chan bool { + return x + }, + } + antiFloodHandler := &mock2.P2PAntifloodHandlerStub{} + handlerArgs := &SubroundsHandlerArgs{ + Chronology: chronology, + ConsensusState: consensusState, + Worker: worker, + SignatureThrottler: &common.ThrottlerStub{}, + AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, + OutportHandler: &outportStub.OutportStub{}, + SentSignatureTracker: &testscommon.SentSignatureTrackerStub{}, + EnableEpochsHandler: epochsEnable, + ChainID: []byte("chainID"), + CurrentPid: "peerID", + } + + consensusCore := &spos.ConsensusCore{} + consensusCore.SetEpochStartNotifier(epochStartNotifier) + consensusCore.SetBlockchain(&testscommon.ChainHandlerStub{}) + consensusCore.SetBlockProcessor(&testscommon.BlockProcessorStub{}) + consensusCore.SetBootStrapper(&bootstrapperStubs.BootstrapperStub{}) + consensusCore.SetBroadcastMessenger(&consensus.BroadcastMessengerMock{}) + consensusCore.SetChronology(chronology) + consensusCore.SetAntifloodHandler(antiFloodHandler) + consensusCore.SetHasher(&testscommon.HasherStub{}) + consensusCore.SetMarshalizer(&testscommon.MarshallerStub{}) + consensusCore.SetMultiSignerContainer(&cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return &cryptoMocks.MultisignerMock{}, nil + }, + }) + consensusCore.SetRoundHandler(&consensus.RoundHandlerMock{}) + consensusCore.SetShardCoordinator(&testscommon.ShardsCoordinatorMock{}) + consensusCore.SetSyncTimer(&testscommon.SyncTimerStub{}) + consensusCore.SetNodesCoordinator(&shardingMocks.NodesCoordinatorMock{}) + consensusCore.SetPeerHonestyHandler(&testscommon.PeerHonestyHandlerStub{}) + consensusCore.SetHeaderSigVerifier(&consensus.HeaderSigVerifierMock{}) + consensusCore.SetFallbackHeaderValidator(&testscommon.FallBackHeaderValidatorStub{}) + consensusCore.SetNodeRedundancyHandler(&mock2.NodeRedundancyHandlerStub{}) + 
consensusCore.SetScheduledProcessor(&consensus.ScheduledProcessorStub{}) + consensusCore.SetMessageSigningHandler(&mock2.MessageSigningHandlerStub{}) + consensusCore.SetPeerBlacklistHandler(&mock2.PeerBlacklistHandlerStub{}) + consensusCore.SetSigningHandler(&consensus.SigningHandlerStub{}) + consensusCore.SetEnableEpochsHandler(epochsEnable) + consensusCore.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{}) + consensusCore.SetEpochNotifier(epochNotifier) + consensusCore.SetInvalidSignersCache(&consensus.InvalidSignersCacheMock{}) + handlerArgs.ConsensusCoreHandler = consensusCore + + return handlerArgs, consensusCore +} + +func TestNewSubroundsHandler(t *testing.T) { + t.Parallel() + + t.Run("nil chronology should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.Chronology = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilChronologyHandler, err) + require.Nil(t, sh) + }) + t.Run("nil consensus core should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.ConsensusCoreHandler = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilConsensusCoreHandler, err) + require.Nil(t, sh) + }) + t.Run("nil consensus state should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.ConsensusState = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilConsensusState, err) + require.Nil(t, sh) + }) + t.Run("nil worker should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.Worker = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilWorker, err) + require.Nil(t, sh) + }) + t.Run("nil signature throttler should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.SignatureThrottler = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilSignatureThrottler, err) + require.Nil(t, sh) + }) + t.Run("nil app status handler should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.AppStatusHandler = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilAppStatusHandler, err) + require.Nil(t, sh) + }) + t.Run("nil outport handler should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.OutportHandler = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilOutportHandler, err) + require.Nil(t, sh) + }) + t.Run("nil sent signature tracker should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.SentSignatureTracker = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilSentSignatureTracker, err) + require.Nil(t, sh) + }) + t.Run("nil enable epochs handler should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.EnableEpochsHandler = nil + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilEnableEpochsHandler, err) + require.Nil(t, sh) + }) + t.Run("nil chain ID should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.ChainID = nil + sh, err := 
NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilChainID, err) + require.Nil(t, sh) + }) + t.Run("empty current PID should error", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + handlerArgs.CurrentPid = "" + sh, err := NewSubroundsHandler(handlerArgs) + require.Equal(t, ErrNilCurrentPid, err) + require.Nil(t, sh) + }) + t.Run("OK", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + }) +} + +func TestSubroundsHandler_initSubroundsForEpoch(t *testing.T) { + t.Parallel() + + t.Run("equivalent messages not enabled, with previous consensus type not consensusV1", func(t *testing.T) { + t.Parallel() + + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) + sh.currentConsensusType = consensusNone + + err = sh.initSubroundsForEpoch(0) + require.Nil(t, err) + require.Equal(t, consensusV1, sh.currentConsensusType) + require.Equal(t, int32(2), startCalled.Load()) + }) + t.Run("equivalent messages not enabled, with previous consensus type consensusV1", func(t *testing.T) { + t.Parallel() + + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) + sh.currentConsensusType = consensusV1 + + err = sh.initSubroundsForEpoch(0) + require.Nil(t, err) + require.Equal(t, consensusV1, sh.currentConsensusType) + require.Equal(t, int32(1), startCalled.Load()) + + }) + t.Run("equivalent messages enabled, with previous consensus type consensusNone", func(t *testing.T) { + t.Parallel() + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := 
NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) + sh.currentConsensusType = consensusNone + + err = sh.initSubroundsForEpoch(0) + require.Nil(t, err) + require.Equal(t, consensusV2, sh.currentConsensusType) + require.Equal(t, int32(2), startCalled.Load()) + }) + t.Run("equivalent messages enabled, with previous consensus type consensusV1", func(t *testing.T) { + t.Parallel() + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) + sh.currentConsensusType = consensusV1 + + err = sh.initSubroundsForEpoch(0) + require.Nil(t, err) + require.Equal(t, consensusV2, sh.currentConsensusType) + require.Equal(t, int32(2), startCalled.Load()) + }) + t.Run("equivalent messages enabled, with previous consensus type consensusV2", func(t *testing.T) { + t.Parallel() + + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) + sh.currentConsensusType = consensusV2 + + err = sh.initSubroundsForEpoch(0) + require.Nil(t, err) + require.Equal(t, consensusV2, sh.currentConsensusType) + require.Equal(t, int32(1), startCalled.Load()) + }) +} + +func TestSubroundsHandler_Start(t *testing.T) { + t.Parallel() + + // the Start is tested via initSubroundsForEpoch, adding one of the test cases here as well + t.Run("equivalent messages not enabled, with previous consensus type not consensusV1", func(t *testing.T) { + t.Parallel() + + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + // first call on init of EpochNotifier + 
require.Equal(t, int32(1), startCalled.Load()) + sh.currentConsensusType = consensusNone + + err = sh.Start(0) + require.Nil(t, err) + require.Equal(t, consensusV1, sh.currentConsensusType) + require.Equal(t, int32(2), startCalled.Load()) + }) +} + +func TestSubroundsHandler_IsInterfaceNil(t *testing.T) { + t.Parallel() + + t.Run("nil handler", func(t *testing.T) { + t.Parallel() + + var sh *SubroundsHandler + require.True(t, sh.IsInterfaceNil()) + }) + t.Run("not nil handler", func(t *testing.T) { + t.Parallel() + + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + + require.False(t, sh.IsInterfaceNil()) + }) +} + +func TestSubroundsHandler_EpochConfirmed(t *testing.T) { + t.Parallel() + + t.Run("nil handler does not panic", func(t *testing.T) { + t.Parallel() + + defer func() { + if r := recover(); r != nil { + t.Errorf("The code panicked") + } + }() + handlerArgs, _ := getDefaultArgumentsSubroundHandler() + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + sh.EpochConfirmed(0, 0) + }) + + // tested through initSubroundsForEpoch + t.Run("OK", func(t *testing.T) { + t.Parallel() + + startCalled := atomic.Int32{} + handlerArgs, consensusCore := getDefaultArgumentsSubroundHandler() + chronology := &consensus.ChronologyHandlerMock{ + StartRoundCalled: func() { + startCalled.Add(1) + }, + } + enableEpoch := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + handlerArgs.Chronology = chronology + handlerArgs.EnableEpochsHandler = enableEpoch + consensusCore.SetEnableEpochsHandler(enableEpoch) + consensusCore.SetChronology(chronology) + + sh, err := NewSubroundsHandler(handlerArgs) + require.Nil(t, err) + require.NotNil(t, sh) + // first call on register to EpochNotifier + require.Equal(t, int32(1), startCalled.Load()) + + sh.currentConsensusType = consensusNone + sh.EpochConfirmed(0, 0) + require.Nil(t, err) + require.Equal(t, consensusV1, sh.currentConsensusType) + require.Equal(t, int32(2), startCalled.Load()) + }) +} diff --git a/consensus/spos/bls/blsSubroundsFactory.go b/consensus/spos/bls/v1/blsSubroundsFactory.go similarity index 78% rename from consensus/spos/bls/blsSubroundsFactory.go rename to consensus/spos/bls/v1/blsSubroundsFactory.go index aeb64a5775a..385d1603a5a 100644 --- a/consensus/spos/bls/blsSubroundsFactory.go +++ b/consensus/spos/bls/v1/blsSubroundsFactory.go @@ -1,11 +1,13 @@ -package bls +package v1 import ( "time" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" "github.com/multiversx/mx-chain-go/outport" ) @@ -13,7 +15,7 @@ import ( // functionality type factory struct { consensusCore spos.ConsensusCoreHandler - consensusState *spos.ConsensusState + consensusState spos.ConsensusStateHandler worker spos.WorkerHandler appStatusHandler core.AppStatusHandler @@ -26,13 +28,15 @@ type factory struct { // NewSubroundsFactory creates a new consensusState object func NewSubroundsFactory( consensusDataContainer spos.ConsensusCoreHandler, - consensusState *spos.ConsensusState, + consensusState spos.ConsensusStateHandler, worker spos.WorkerHandler, chainID []byte, currentPid core.PeerID, appStatusHandler 
core.AppStatusHandler, sentSignaturesTracker spos.SentSignaturesTracker, + outportHandler outport.OutportHandler, ) (*factory, error) { + // no need to check the outportHandler, it can be nil err := checkNewFactoryParams( consensusDataContainer, consensusState, @@ -53,6 +57,7 @@ func NewSubroundsFactory( chainID: chainID, currentPid: currentPid, sentSignaturesTracker: sentSignaturesTracker, + outportHandler: outportHandler, } return &fct, nil @@ -60,7 +65,7 @@ func NewSubroundsFactory( func checkNewFactoryParams( container spos.ConsensusCoreHandler, - state *spos.ConsensusState, + state spos.ConsensusStateHandler, worker spos.WorkerHandler, chainID []byte, appStatusHandler core.AppStatusHandler, @@ -70,7 +75,7 @@ func checkNewFactoryParams( if err != nil { return err } - if state == nil { + if check.IfNil(state) { return spos.ErrNilConsensusState } if check.IfNil(worker) { @@ -95,10 +100,11 @@ func (fct *factory) SetOutportHandler(driver outport.OutportHandler) { } // GenerateSubrounds will generate the subrounds used in BLS Cns -func (fct *factory) GenerateSubrounds() error { +func (fct *factory) GenerateSubrounds(_ uint32) error { fct.initConsensusThreshold() fct.consensusCore.Chronology().RemoveAllSubrounds() fct.worker.RemoveAllReceivedMessagesCalls() + fct.worker.RemoveAllReceivedHeaderHandlers() err := fct.generateStartRoundSubround() if err != nil { @@ -130,11 +136,11 @@ func (fct *factory) getTimeDuration() time.Duration { func (fct *factory) generateStartRoundSubround() error { subround, err := spos.NewSubround( -1, - SrStartRound, - SrBlock, + bls.SrStartRound, + bls.SrBlock, int64(float64(fct.getTimeDuration())*srStartStartTime), int64(float64(fct.getTimeDuration())*srStartEndTime), - getSubroundName(SrStartRound), + bls.GetSubroundName(bls.SrStartRound), fct.consensusState, fct.worker.GetConsensusStateChangedChannel(), fct.worker.ExecuteStoredMessages, @@ -171,12 +177,12 @@ func (fct *factory) generateStartRoundSubround() error { func (fct *factory) generateBlockSubround() error { subround, err := spos.NewSubround( - SrStartRound, - SrBlock, - SrSignature, + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, int64(float64(fct.getTimeDuration())*srBlockStartTime), int64(float64(fct.getTimeDuration())*srBlockEndTime), - getSubroundName(SrBlock), + bls.GetSubroundName(bls.SrBlock), fct.consensusState, fct.worker.GetConsensusStateChangedChannel(), fct.worker.ExecuteStoredMessages, @@ -198,9 +204,10 @@ func (fct *factory) generateBlockSubround() error { return err } - fct.worker.AddReceivedMessageCall(MtBlockBodyAndHeader, subroundBlockInstance.receivedBlockBodyAndHeader) - fct.worker.AddReceivedMessageCall(MtBlockBody, subroundBlockInstance.receivedBlockBody) - fct.worker.AddReceivedMessageCall(MtBlockHeader, subroundBlockInstance.receivedBlockHeader) + fct.worker.AddReceivedMessageCall(bls.MtBlockBodyAndHeader, subroundBlockInstance.receivedBlockBodyAndHeader) + fct.worker.AddReceivedMessageCall(bls.MtBlockBody, subroundBlockInstance.receivedBlockBody) + fct.worker.AddReceivedMessageCall(bls.MtBlockHeader, subroundBlockInstance.receivedBlockHeader) + fct.worker.AddReceivedHeaderHandler(subroundBlockInstance.receivedFullHeader) fct.consensusCore.Chronology().AddSubround(subroundBlockInstance) return nil @@ -208,12 +215,12 @@ func (fct *factory) generateBlockSubround() error { func (fct *factory) generateSignatureSubround() error { subround, err := spos.NewSubround( - SrBlock, - SrSignature, - SrEndRound, + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, 
int64(float64(fct.getTimeDuration())*srSignatureStartTime), int64(float64(fct.getTimeDuration())*srSignatureEndTime), - getSubroundName(SrSignature), + bls.GetSubroundName(bls.SrSignature), fct.consensusState, fct.worker.GetConsensusStateChangedChannel(), fct.worker.ExecuteStoredMessages, @@ -236,7 +243,7 @@ func (fct *factory) generateSignatureSubround() error { return err } - fct.worker.AddReceivedMessageCall(MtSignature, subroundSignatureObject.receivedSignature) + fct.worker.AddReceivedMessageCall(bls.MtSignature, subroundSignatureObject.receivedSignature) fct.consensusCore.Chronology().AddSubround(subroundSignatureObject) return nil @@ -244,12 +251,12 @@ func (fct *factory) generateSignatureSubround() error { func (fct *factory) generateEndRoundSubround() error { subround, err := spos.NewSubround( - SrSignature, - SrEndRound, + bls.SrSignature, + bls.SrEndRound, -1, int64(float64(fct.getTimeDuration())*srEndStartTime), int64(float64(fct.getTimeDuration())*srEndEndTime), - getSubroundName(SrEndRound), + bls.GetSubroundName(bls.SrEndRound), fct.consensusState, fct.worker.GetConsensusStateChangedChannel(), fct.worker.ExecuteStoredMessages, @@ -274,8 +281,8 @@ func (fct *factory) generateEndRoundSubround() error { return err } - fct.worker.AddReceivedMessageCall(MtBlockHeaderFinalInfo, subroundEndRoundObject.receivedBlockHeaderFinalInfo) - fct.worker.AddReceivedMessageCall(MtInvalidSigners, subroundEndRoundObject.receivedInvalidSignersInfo) + fct.worker.AddReceivedMessageCall(bls.MtBlockHeaderFinalInfo, subroundEndRoundObject.receivedBlockHeaderFinalInfo) + fct.worker.AddReceivedMessageCall(bls.MtInvalidSigners, subroundEndRoundObject.receivedInvalidSignersInfo) fct.worker.AddReceivedHeaderHandler(subroundEndRoundObject.receivedHeader) fct.consensusCore.Chronology().AddSubround(subroundEndRoundObject) @@ -285,10 +292,10 @@ func (fct *factory) generateEndRoundSubround() error { func (fct *factory) initConsensusThreshold() { pBFTThreshold := core.GetPBFTThreshold(fct.consensusState.ConsensusGroupSize()) pBFTFallbackThreshold := core.GetPBFTFallbackThreshold(fct.consensusState.ConsensusGroupSize()) - fct.consensusState.SetThreshold(SrBlock, 1) - fct.consensusState.SetThreshold(SrSignature, pBFTThreshold) - fct.consensusState.SetFallbackThreshold(SrBlock, 1) - fct.consensusState.SetFallbackThreshold(SrSignature, pBFTFallbackThreshold) + fct.consensusState.SetThreshold(bls.SrBlock, 1) + fct.consensusState.SetThreshold(bls.SrSignature, pBFTThreshold) + fct.consensusState.SetFallbackThreshold(bls.SrBlock, 1) + fct.consensusState.SetFallbackThreshold(bls.SrSignature, pBFTFallbackThreshold) } // IsInterfaceNil returns true if there is no value under the interface diff --git a/consensus/spos/bls/blsSubroundsFactory_test.go b/consensus/spos/bls/v1/blsSubroundsFactory_test.go similarity index 74% rename from consensus/spos/bls/blsSubroundsFactory_test.go rename to consensus/spos/bls/v1/blsSubroundsFactory_test.go index af3267a78cc..897e83d5593 100644 --- a/consensus/spos/bls/blsSubroundsFactory_test.go +++ b/consensus/spos/bls/v1/blsSubroundsFactory_test.go @@ -1,4 +1,4 @@ -package bls_test +package v1_test import ( "context" @@ -8,15 +8,18 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/consensus/mock" 
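The fractional boundaries used above (srBlockStartTime, srSignatureEndTime and so on) are applied to the round duration returned by getTimeDuration(). A hypothetical helper, shown only to make that arithmetic explicit (not part of the change set, and it assumes a time import):

func subroundWindow(roundDuration time.Duration, startFraction, endFraction float64) (int64, int64) {
	// Mirrors the factory's int64(float64(roundDuration) * fraction) computation:
	// with srBlockStartTime = 0.05 and srBlockEndTime = 0.25, the Block subround
	// occupies the 5%..25% slice of the round.
	start := int64(float64(roundDuration) * startFraction)
	end := int64(float64(roundDuration) * endFraction)
	return start, end
}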
"github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" "github.com/multiversx/mx-chain-go/outport" "github.com/multiversx/mx-chain-go/testscommon" + consensusMock "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" testscommonOutport "github.com/multiversx/mx-chain-go/testscommon/outport" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) var chainID = []byte("chain ID") @@ -40,8 +43,8 @@ func executeStoredMessages() { func resetConsensusMessages() { } -func initRoundHandlerMock() *mock.RoundHandlerMock { - return &mock.RoundHandlerMock{ +func initRoundHandlerMock() *consensusMock.RoundHandlerMock { + return &consensusMock.RoundHandlerMock{ RoundIndex: 0, TimeStampCalled: func() time.Time { return time.Unix(0, 0) @@ -53,7 +56,7 @@ func initRoundHandlerMock() *mock.RoundHandlerMock { } func initWorker() spos.WorkerHandler { - sposWorker := &mock.SposWorkerMock{} + sposWorker := &consensusMock.SposWorkerMock{} sposWorker.GetConsensusStateChangedChannelsCalled = func() chan bool { return make(chan bool) } @@ -66,11 +69,11 @@ func initWorker() spos.WorkerHandler { return sposWorker } -func initFactoryWithContainer(container *mock.ConsensusCoreMock) bls.Factory { +func initFactoryWithContainer(container *spos.ConsensusCore) v1.Factory { worker := initWorker() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() - fct, _ := bls.NewSubroundsFactory( + fct, _ := v1.NewSubroundsFactory( container, consensusState, worker, @@ -78,13 +81,14 @@ func initFactoryWithContainer(container *mock.ConsensusCoreMock) bls.Factory { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) return fct } -func initFactory() bls.Factory { - container := mock.InitConsensusCore() +func initFactory() v1.Factory { + container := consensusMock.InitConsensusCore() return initFactoryWithContainer(container) } @@ -116,10 +120,10 @@ func TestFactory_GetMessageTypeName(t *testing.T) { func TestFactory_NewFactoryNilContainerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() worker := initWorker() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( nil, consensusState, worker, @@ -127,6 +131,7 @@ func TestFactory_NewFactoryNilContainerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -136,10 +141,10 @@ func TestFactory_NewFactoryNilContainerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() worker := initWorker() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, nil, worker, @@ -147,6 +152,7 @@ func TestFactory_NewFactoryNilConsensusStateShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -156,12 +162,12 @@ func 
TestFactory_NewFactoryNilConsensusStateShouldFail(t *testing.T) { func TestFactory_NewFactoryNilBlockchainShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetBlockchain(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -169,6 +175,7 @@ func TestFactory_NewFactoryNilBlockchainShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -178,12 +185,12 @@ func TestFactory_NewFactoryNilBlockchainShouldFail(t *testing.T) { func TestFactory_NewFactoryNilBlockProcessorShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetBlockProcessor(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -191,6 +198,7 @@ func TestFactory_NewFactoryNilBlockProcessorShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -200,12 +208,12 @@ func TestFactory_NewFactoryNilBlockProcessorShouldFail(t *testing.T) { func TestFactory_NewFactoryNilBootstrapperShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetBootStrapper(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -213,6 +221,7 @@ func TestFactory_NewFactoryNilBootstrapperShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -222,12 +231,12 @@ func TestFactory_NewFactoryNilBootstrapperShouldFail(t *testing.T) { func TestFactory_NewFactoryNilChronologyHandlerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetChronology(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -235,6 +244,7 @@ func TestFactory_NewFactoryNilChronologyHandlerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -244,12 +254,12 @@ func TestFactory_NewFactoryNilChronologyHandlerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilHasherShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetHasher(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -257,6 +267,7 @@ func TestFactory_NewFactoryNilHasherShouldFail(t *testing.T) { currentPid, 
&statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -266,12 +277,12 @@ func TestFactory_NewFactoryNilHasherShouldFail(t *testing.T) { func TestFactory_NewFactoryNilMarshalizerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetMarshalizer(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -279,6 +290,7 @@ func TestFactory_NewFactoryNilMarshalizerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -288,12 +300,12 @@ func TestFactory_NewFactoryNilMarshalizerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetMultiSignerContainer(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -301,6 +313,7 @@ func TestFactory_NewFactoryNilMultiSignerContainerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -310,12 +323,12 @@ func TestFactory_NewFactoryNilMultiSignerContainerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetRoundHandler(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -323,6 +336,7 @@ func TestFactory_NewFactoryNilRoundHandlerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -332,12 +346,12 @@ func TestFactory_NewFactoryNilRoundHandlerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilShardCoordinatorShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetShardCoordinator(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -345,6 +359,7 @@ func TestFactory_NewFactoryNilShardCoordinatorShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -354,12 +369,12 @@ func TestFactory_NewFactoryNilShardCoordinatorShouldFail(t *testing.T) { func TestFactory_NewFactoryNilSyncTimerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() container.SetSyncTimer(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := 
v1.NewSubroundsFactory( container, consensusState, worker, @@ -367,6 +382,7 @@ func TestFactory_NewFactoryNilSyncTimerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -376,12 +392,12 @@ func TestFactory_NewFactoryNilSyncTimerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilValidatorGroupSelectorShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() - container.SetValidatorGroupSelector(nil) + container.SetNodesCoordinator(nil) - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -389,6 +405,7 @@ func TestFactory_NewFactoryNilValidatorGroupSelectorShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -398,10 +415,10 @@ func TestFactory_NewFactoryNilValidatorGroupSelectorShouldFail(t *testing.T) { func TestFactory_NewFactoryNilWorkerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, nil, @@ -409,6 +426,7 @@ func TestFactory_NewFactoryNilWorkerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -418,11 +436,11 @@ func TestFactory_NewFactoryNilWorkerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilAppStatusHandlerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -430,6 +448,7 @@ func TestFactory_NewFactoryNilAppStatusHandlerShouldFail(t *testing.T) { currentPid, nil, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -439,11 +458,11 @@ func TestFactory_NewFactoryNilAppStatusHandlerShouldFail(t *testing.T) { func TestFactory_NewFactoryNilSignaturesTrackerShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMock.InitConsensusCore() worker := initWorker() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -451,10 +470,11 @@ func TestFactory_NewFactoryNilSignaturesTrackerShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, nil, + nil, ) assert.Nil(t, fct) - assert.Equal(t, bls.ErrNilSentSignatureTracker, err) + assert.Equal(t, v1.ErrNilSentSignatureTracker, err) } func TestFactory_NewFactoryShouldWork(t *testing.T) { @@ -468,11 +488,11 @@ func TestFactory_NewFactoryShouldWork(t *testing.T) { func TestFactory_NewFactoryEmptyChainIDShouldFail(t *testing.T) { t.Parallel() - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container 
:= consensusMock.InitConsensusCore() worker := initWorker() - fct, err := bls.NewSubroundsFactory( + fct, err := v1.NewSubroundsFactory( container, consensusState, worker, @@ -480,6 +500,7 @@ func TestFactory_NewFactoryEmptyChainIDShouldFail(t *testing.T) { currentPid, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, + nil, ) assert.Nil(t, fct) @@ -490,7 +511,7 @@ func TestFactory_GenerateSubroundStartRoundShouldFailWhenNewSubroundFail(t *test t.Parallel() fct := *initFactory() - fct.Worker().(*mock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + fct.Worker().(*consensusMock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { return nil } @@ -502,7 +523,7 @@ func TestFactory_GenerateSubroundStartRoundShouldFailWhenNewSubroundFail(t *test func TestFactory_GenerateSubroundStartRoundShouldFailWhenNewSubroundStartRoundFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) container.SetSyncTimer(nil) @@ -515,7 +536,7 @@ func TestFactory_GenerateSubroundBlockShouldFailWhenNewSubroundFail(t *testing.T t.Parallel() fct := *initFactory() - fct.Worker().(*mock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + fct.Worker().(*consensusMock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { return nil } @@ -527,7 +548,7 @@ func TestFactory_GenerateSubroundBlockShouldFailWhenNewSubroundFail(t *testing.T func TestFactory_GenerateSubroundBlockShouldFailWhenNewSubroundBlockFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) container.SetSyncTimer(nil) @@ -540,7 +561,7 @@ func TestFactory_GenerateSubroundSignatureShouldFailWhenNewSubroundFail(t *testi t.Parallel() fct := *initFactory() - fct.Worker().(*mock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + fct.Worker().(*consensusMock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { return nil } @@ -552,7 +573,7 @@ func TestFactory_GenerateSubroundSignatureShouldFailWhenNewSubroundFail(t *testi func TestFactory_GenerateSubroundSignatureShouldFailWhenNewSubroundSignatureFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) container.SetSyncTimer(nil) @@ -565,7 +586,7 @@ func TestFactory_GenerateSubroundEndRoundShouldFailWhenNewSubroundFail(t *testin t.Parallel() fct := *initFactory() - fct.Worker().(*mock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + fct.Worker().(*consensusMock.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { return nil } @@ -577,7 +598,7 @@ func TestFactory_GenerateSubroundEndRoundShouldFailWhenNewSubroundFail(t *testin func TestFactory_GenerateSubroundEndRoundShouldFailWhenNewSubroundEndRoundFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) container.SetSyncTimer(nil) @@ -591,16 +612,16 @@ func TestFactory_GenerateSubroundsShouldWork(t *testing.T) { subroundHandlers := 0 - chrm := &mock.ChronologyHandlerMock{} + chrm := &consensusMock.ChronologyHandlerMock{} chrm.AddSubroundCalled = func(subroundHandler consensus.SubroundHandler) { 
subroundHandlers++ } - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() container.SetChronology(chrm) fct := *initFactoryWithContainer(container) fct.SetOutportHandler(&testscommonOutport.OutportStub{}) - err := fct.GenerateSubrounds() + err := fct.GenerateSubrounds(0) assert.Nil(t, err) assert.Equal(t, 4, subroundHandlers) @@ -609,17 +630,17 @@ func TestFactory_GenerateSubroundsShouldWork(t *testing.T) { func TestFactory_GenerateSubroundsNilOutportShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) - err := fct.GenerateSubrounds() + err := fct.GenerateSubrounds(0) assert.Equal(t, outport.ErrNilDriver, err) } func TestFactory_SetIndexerShouldWork(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() fct := *initFactoryWithContainer(container) outportHandler := &testscommonOutport.OutportStub{} diff --git a/consensus/spos/bls/v1/constants.go b/consensus/spos/bls/v1/constants.go new file mode 100644 index 00000000000..fc35333e15a --- /dev/null +++ b/consensus/spos/bls/v1/constants.go @@ -0,0 +1,37 @@ +package v1 + +import ( + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("consensus/spos/bls/v1") + +// waitingAllSigsMaxTimeThreshold specifies the max allocated time for waiting all signatures from the total time of the subround signature +const waitingAllSigsMaxTimeThreshold = 0.5 + +// processingThresholdPercent specifies the max allocated time for processing the block as a percentage of the total time of the round +const processingThresholdPercent = 85 + +// srStartStartTime specifies the start time, from the total time of the round, of Subround Start +const srStartStartTime = 0.0 + +// srEndStartTime specifies the end time, from the total time of the round, of Subround Start +const srStartEndTime = 0.05 + +// srBlockStartTime specifies the start time, from the total time of the round, of Subround Block +const srBlockStartTime = 0.05 + +// srBlockEndTime specifies the end time, from the total time of the round, of Subround Block +const srBlockEndTime = 0.25 + +// srSignatureStartTime specifies the start time, from the total time of the round, of Subround Signature +const srSignatureStartTime = 0.25 + +// srSignatureEndTime specifies the end time, from the total time of the round, of Subround Signature +const srSignatureEndTime = 0.85 + +// srEndStartTime specifies the start time, from the total time of the round, of Subround End +const srEndStartTime = 0.85 + +// srEndEndTime specifies the end time, from the total time of the round, of Subround End +const srEndEndTime = 0.95 diff --git a/consensus/spos/bls/v1/errors.go b/consensus/spos/bls/v1/errors.go new file mode 100644 index 00000000000..b49c581419d --- /dev/null +++ b/consensus/spos/bls/v1/errors.go @@ -0,0 +1,9 @@ +package v1 + +import "errors" + +// ErrNilSentSignatureTracker defines the error for setting a nil SentSignatureTracker +var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") + +// ErrAndromedaFlagEnabledWithConsensusV1 defines the error for running andromeda enabled under v1 consensus +var ErrAndromedaFlagEnabledWithConsensusV1 = errors.New("andromeda flag enabled with consensus v1") diff --git a/consensus/spos/bls/export_test.go b/consensus/spos/bls/v1/export_test.go similarity index 94% rename from 
consensus/spos/bls/export_test.go rename to consensus/spos/bls/v1/export_test.go index 71d3cfc8348..4a386a57933 100644 --- a/consensus/spos/bls/export_test.go +++ b/consensus/spos/bls/v1/export_test.go @@ -1,4 +1,4 @@ -package bls +package v1 import ( "context" @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" @@ -18,9 +19,8 @@ import ( "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" ) +// ProcessingThresholdPercent exports the internal processingThresholdPercent const ProcessingThresholdPercent = processingThresholdPercent -const DefaultMaxNumOfMessageTypeAccepted = defaultMaxNumOfMessageTypeAccepted -const MaxNumOfMessageTypeSignatureAccepted = maxNumOfMessageTypeSignatureAccepted // factory @@ -48,7 +48,7 @@ func (fct *factory) ChronologyHandler() consensus.ChronologyHandler { } // ConsensusState gets the consensus state struct pointer -func (fct *factory) ConsensusState() *spos.ConsensusState { +func (fct *factory) ConsensusState() spos.ConsensusStateHandler { return fct.consensusState } @@ -129,8 +129,8 @@ func (fct *factory) Outport() outport.OutportHandler { // subroundStartRound -// SubroundStartRound defines a type for the subroundStartRound structure -type SubroundStartRound *subroundStartRound +// SubroundStartRound defines an alias to the subroundStartRound structure +type SubroundStartRound = *subroundStartRound // DoStartRoundJob method does the job of the subround StartRound func (sr *subroundStartRound) DoStartRoundJob() bool { @@ -160,7 +160,7 @@ func (sr *subroundStartRound) GetSentSignatureTracker() spos.SentSignaturesTrack // subroundBlock // SubroundBlock defines a type for the subroundBlock structure -type SubroundBlock *subroundBlock +type SubroundBlock = *subroundBlock // Blockchain gets the ChainHandler stored in the ConsensusCore func (sr *subroundBlock) BlockChain() data.ChainHandler { @@ -229,8 +229,8 @@ func (sr *subroundBlock) ReceivedBlockBodyAndHeader(cnsDta *consensus.Message) b // subroundSignature -// SubroundSignature defines a type for the subroundSignature structure -type SubroundSignature *subroundSignature +// SubroundSignature defines an alias for the subroundSignature structure +type SubroundSignature = *subroundSignature // DoSignatureJob method does the job of the subround Signature func (sr *subroundSignature) DoSignatureJob() bool { @@ -254,8 +254,8 @@ func (sr *subroundSignature) AreSignaturesCollected(threshold int) (bool, int) { // subroundEndRound -// SubroundEndRound defines a type for the subroundEndRound structure -type SubroundEndRound *subroundEndRound +// SubroundEndRound defines an alias for the subroundEndRound structure +type SubroundEndRound = *subroundEndRound // DoEndRoundJob method does the job of the subround EndRound func (sr *subroundEndRound) DoEndRoundJob() bool { @@ -351,8 +351,3 @@ func (sr *subroundEndRound) GetFullMessagesForInvalidSigners(invalidPubKeys []st func (sr *subroundEndRound) GetSentSignatureTracker() spos.SentSignaturesTracker { return sr.sentSignatureTracker } - -// GetStringValue calls the unexported getStringValue function -func GetStringValue(messageType consensus.MessageType) string { - return 
getStringValue(messageType) -} diff --git a/consensus/spos/bls/subroundBlock.go b/consensus/spos/bls/v1/subroundBlock.go similarity index 84% rename from consensus/spos/bls/subroundBlock.go rename to consensus/spos/bls/v1/subroundBlock.go index a83969721b8..bf6a0bf56c7 100644 --- a/consensus/spos/bls/subroundBlock.go +++ b/consensus/spos/bls/v1/subroundBlock.go @@ -1,15 +1,18 @@ -package bls +package v1 import ( + "bytes" "context" "time" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" ) // maxAllowedSizeInBytes defines how many bytes are allowed as payload in a message @@ -52,7 +55,7 @@ func checkNewSubroundBlockParams( return spos.ErrNilSubround } - if baseSubround.ConsensusState == nil { + if check.IfNil(baseSubround.ConsensusStateHandler) { return spos.ErrNilConsensusState } @@ -114,7 +117,7 @@ func (sr *subroundBlock) doBlockJob(ctx context.Context) bool { // placeholder for subroundBlock.doBlockJob script - sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(header, body, sr.RoundTimeStamp) + sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(header, body, sr.GetRoundTimeStamp()) return true } @@ -163,7 +166,7 @@ func (sr *subroundBlock) couldBeSentTogether(marshalizedBody []byte, marshalized } func (sr *subroundBlock) createBlock(header data.HeaderHandler) (data.HeaderHandler, data.BodyHandler, error) { - startTime := sr.RoundTimeStamp + startTime := sr.GetRoundTimeStamp() maxTime := time.Duration(sr.EndTime()) haveTimeInCurrentSubround := func() bool { return sr.RoundHandler().RemainingTime(startTime, maxTime) > 0 @@ -202,7 +205,7 @@ func (sr *subroundBlock) sendHeaderAndBlockBody( marshalizedHeader, []byte(leader), nil, - int(MtBlockBodyAndHeader), + int(bls.MtBlockBodyAndHeader), sr.RoundHandler().Index(), sr.ChainID(), nil, @@ -222,9 +225,9 @@ func (sr *subroundBlock) sendHeaderAndBlockBody( "nonce", headerHandler.GetNonce(), "hash", headerHash) - sr.Data = headerHash - sr.Body = bodyHandler - sr.Header = headerHandler + sr.SetData(headerHash) + sr.SetBody(bodyHandler) + sr.SetHeader(headerHandler) return true } @@ -244,7 +247,7 @@ func (sr *subroundBlock) sendBlockBody(bodyHandler data.BodyHandler, marshalized nil, []byte(leader), nil, - int(MtBlockBody), + int(bls.MtBlockBody), sr.RoundHandler().Index(), sr.ChainID(), nil, @@ -262,7 +265,7 @@ func (sr *subroundBlock) sendBlockBody(bodyHandler data.BodyHandler, marshalized log.Debug("step 1: block body has been sent") - sr.Body = bodyHandler + sr.SetBody(bodyHandler) return true } @@ -284,7 +287,7 @@ func (sr *subroundBlock) sendBlockHeader(headerHandler data.HeaderHandler, marsh marshalizedHeader, []byte(leader), nil, - int(MtBlockHeader), + int(bls.MtBlockHeader), sr.RoundHandler().Index(), sr.ChainID(), nil, @@ -304,8 +307,8 @@ func (sr *subroundBlock) sendBlockHeader(headerHandler data.HeaderHandler, marsh "nonce", headerHandler.GetNonce(), "hash", headerHash) - sr.Data = headerHash - sr.Header = headerHandler + sr.SetData(headerHash) + sr.SetHeader(headerHandler) return true } @@ -332,6 +335,10 @@ func (sr *subroundBlock) createHeader() (data.HeaderHandler, error) { return nil, 
err } + if sr.EnableEpochsHandler().IsFlagEnabledInEpoch(common.AndromedaFlag, hdr.GetEpoch()) { + return nil, ErrAndromedaFlagEnabledWithConsensusV1 + } + err = hdr.SetPrevHash(prevHash) if err != nil { return nil, err @@ -413,17 +420,22 @@ func (sr *subroundBlock) receivedBlockBodyAndHeader(ctx context.Context, cnsDta return false } - sr.Data = cnsDta.BlockHeaderHash - sr.Body = sr.BlockProcessor().DecodeBlockBody(cnsDta.Body) - sr.Header = sr.BlockProcessor().DecodeBlockHeader(cnsDta.Header) + header := sr.BlockProcessor().DecodeBlockHeader(cnsDta.Header) + if sr.isFlagActiveForHeader(header) { + return false + } + + sr.SetData(cnsDta.BlockHeaderHash) + sr.SetBody(sr.BlockProcessor().DecodeBlockBody(cnsDta.Body)) + sr.SetHeader(header) - isInvalidData := check.IfNil(sr.Body) || sr.isInvalidHeaderOrData() + isInvalidData := check.IfNil(sr.GetBody()) || sr.isInvalidHeaderOrData() if isInvalidData { return false } log.Debug("step 1: block body and header have been received", - "nonce", sr.Header.GetNonce(), + "nonce", sr.GetHeader().GetNonce(), "hash", cnsDta.BlockHeaderHash) sw.Start("processReceivedBlock") @@ -440,7 +452,7 @@ func (sr *subroundBlock) receivedBlockBodyAndHeader(ctx context.Context, cnsDta } func (sr *subroundBlock) isInvalidHeaderOrData() bool { - return sr.Data == nil || check.IfNil(sr.Header) || sr.Header.CheckFieldsForNil() != nil + return sr.GetData() == nil || check.IfNil(sr.GetHeader()) || sr.GetHeader().CheckFieldsForNil() != nil } // receivedBlockBody method is called when a block body is received through the block body channel @@ -465,9 +477,9 @@ func (sr *subroundBlock) receivedBlockBody(ctx context.Context, cnsDta *consensu return false } - sr.Body = sr.BlockProcessor().DecodeBlockBody(cnsDta.Body) + sr.SetBody(sr.BlockProcessor().DecodeBlockBody(cnsDta.Body)) - if check.IfNil(sr.Body) { + if check.IfNil(sr.GetBody()) { return false } @@ -484,6 +496,27 @@ func (sr *subroundBlock) receivedBlockBody(ctx context.Context, cnsDta *consensu return blockProcessedWithSuccess } +func (sr *subroundBlock) receivedFullHeader(headerHandler data.HeaderHandler) { + if sr.ShardCoordinator().SelfId() != headerHandler.GetShardID() { + log.Debug("subroundBlock.ReceivedFullHeader early exit", "headerShardID", headerHandler.GetShardID(), "selfShardID", sr.ShardCoordinator().SelfId()) + return + } + + if !sr.EnableEpochsHandler().IsFlagEnabledInEpoch(common.AndromedaFlag, headerHandler.GetEpoch()) { + log.Debug("subroundBlock.ReceivedFullHeader early exit", "flagNotEnabled in header epoch", headerHandler.GetEpoch()) + return + } + + log.Debug("subroundBlock.ReceivedFullHeader", "nonce", headerHandler.GetNonce(), "epoch", headerHandler.GetEpoch()) + + lastCommittedBlockHash := sr.Blockchain().GetCurrentBlockHeaderHash() + if bytes.Equal(lastCommittedBlockHash, headerHandler.GetPrevHash()) { + // Need to switch to consensus v2 + log.Debug("subroundBlock.ReceivedFullHeader switching epoch") + go sr.EpochNotifier().CheckEpoch(headerHandler) + } +} + // receivedBlockHeader method is called when a block header is received through the block header channel. 
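A condensed restatement of the epoch-switch trigger implemented in receivedFullHeader above (illustrative only; the helper name shouldTriggerEpochCheck is hypothetical): a received header causes an epoch re-check only when it is for the node's own shard, the Andromeda flag is active for its epoch, and it extends the current chain head.

func shouldTriggerEpochCheck(sr *subroundBlock, header data.HeaderHandler) bool {
	sameShard := header.GetShardID() == sr.ShardCoordinator().SelfId()
	flagActive := sr.EnableEpochsHandler().IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch())
	extendsHead := bytes.Equal(sr.Blockchain().GetCurrentBlockHeaderHash(), header.GetPrevHash())

	return sameShard && flagActive && extendsHead
}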
// If the block header is valid, then the validatorRoundStates map corresponding to the node which sent it, // is set on true for the subround Block @@ -512,15 +545,20 @@ func (sr *subroundBlock) receivedBlockHeader(ctx context.Context, cnsDta *consen return false } - sr.Data = cnsDta.BlockHeaderHash - sr.Header = sr.BlockProcessor().DecodeBlockHeader(cnsDta.Header) + header := sr.BlockProcessor().DecodeBlockHeader(cnsDta.Header) + if sr.isFlagActiveForHeader(header) { + return false + } + + sr.SetData(cnsDta.BlockHeaderHash) + sr.SetHeader(header) if sr.isInvalidHeaderOrData() { return false } log.Debug("step 1: block header has been received", - "nonce", sr.Header.GetNonce(), + "nonce", sr.GetHeader().GetNonce(), "hash", cnsDta.BlockHeaderHash) blockProcessedWithSuccess := sr.processReceivedBlock(ctx, cnsDta) @@ -533,11 +571,18 @@ func (sr *subroundBlock) receivedBlockHeader(ctx context.Context, cnsDta *consen return blockProcessedWithSuccess } +func (sr *subroundBlock) isFlagActiveForHeader(headerHandler data.HeaderHandler) bool { + if check.IfNil(headerHandler) { + return false + } + return sr.EnableEpochsHandler().IsFlagEnabledInEpoch(common.AndromedaFlag, headerHandler.GetEpoch()) +} + func (sr *subroundBlock) processReceivedBlock(ctx context.Context, cnsDta *consensus.Message) bool { - if check.IfNil(sr.Body) { + if check.IfNil(sr.GetBody()) { return false } - if check.IfNil(sr.Header) { + if check.IfNil(sr.GetHeader()) { return false } @@ -547,20 +592,20 @@ func (sr *subroundBlock) processReceivedBlock(ctx context.Context, cnsDta *conse sr.SetProcessingBlock(true) - shouldNotProcessBlock := sr.ExtendedCalled || cnsDta.RoundIndex < sr.RoundHandler().Index() + shouldNotProcessBlock := sr.GetExtendedCalled() || cnsDta.RoundIndex < sr.RoundHandler().Index() if shouldNotProcessBlock { log.Debug("canceled round, extended has been called or round index has been changed", "round", sr.RoundHandler().Index(), "subround", sr.Name(), "cnsDta round", cnsDta.RoundIndex, - "extended called", sr.ExtendedCalled, + "extended called", sr.GetExtendedCalled(), ) return false } node := string(cnsDta.PubKey) - startTime := sr.RoundTimeStamp + startTime := sr.GetRoundTimeStamp() maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 remainingTimeInCurrentRound := func() time.Duration { return sr.RoundHandler().RemainingTime(startTime, maxTime) @@ -570,8 +615,8 @@ func (sr *subroundBlock) processReceivedBlock(ctx context.Context, cnsDta *conse defer sr.computeSubroundProcessingMetric(metricStatTime, common.MetricProcessedProposedBlock) err := sr.BlockProcessor().ProcessBlock( - sr.Header, - sr.Body, + sr.GetHeader(), + sr.GetBody(), remainingTimeInCurrentRound, ) @@ -586,7 +631,7 @@ func (sr *subroundBlock) processReceivedBlock(ctx context.Context, cnsDta *conse if err != nil { sr.printCancelRoundLogMessage(ctx, err) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return false } @@ -597,7 +642,7 @@ func (sr *subroundBlock) processReceivedBlock(ctx context.Context, cnsDta *conse return false } - sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(sr.Header, sr.Body, sr.RoundTimeStamp) + sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(sr.GetHeader(), sr.GetBody(), sr.GetRoundTimeStamp()) return true } @@ -627,7 +672,7 @@ func (sr *subroundBlock) computeSubroundProcessingMetric(startTime time.Time, me // doBlockConsensusCheck method checks if the consensus in the subround Block is achieved func (sr *subroundBlock) 
doBlockConsensusCheck() bool { - if sr.RoundCanceled { + if sr.GetRoundCanceled() { return false } diff --git a/consensus/spos/bls/subroundBlock_test.go b/consensus/spos/bls/v1/subroundBlock_test.go similarity index 74% rename from consensus/spos/bls/subroundBlock_test.go rename to consensus/spos/bls/v1/subroundBlock_test.go index 2354ab92b11..d54a1b9b792 100644 --- a/consensus/spos/bls/subroundBlock_test.go +++ b/consensus/spos/bls/v1/subroundBlock_test.go @@ -1,4 +1,4 @@ -package bls_test +package v1_test import ( "errors" @@ -10,19 +10,23 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" "github.com/multiversx/mx-chain-go/testscommon" + consensusMock "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func defaultSubroundForSRBlock(consensusState *spos.ConsensusState, ch chan bool, - container *mock.ConsensusCoreMock, appStatusHandler core.AppStatusHandler) (*spos.Subround, error) { + container *spos.ConsensusCore, appStatusHandler core.AppStatusHandler) (*spos.Subround, error) { return spos.NewSubround( bls.SrStartRound, bls.SrBlock, @@ -55,21 +59,21 @@ func createDefaultHeader() *block.Header { } } -func defaultSubroundBlockFromSubround(sr *spos.Subround) (bls.SubroundBlock, error) { - srBlock, err := bls.NewSubroundBlock( +func defaultSubroundBlockFromSubround(sr *spos.Subround) (v1.SubroundBlock, error) { + srBlock, err := v1.NewSubroundBlock( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, ) return srBlock, err } -func defaultSubroundBlockWithoutErrorFromSubround(sr *spos.Subround) bls.SubroundBlock { - srBlock, _ := bls.NewSubroundBlock( +func defaultSubroundBlockWithoutErrorFromSubround(sr *spos.Subround) v1.SubroundBlock { + srBlock, _ := v1.NewSubroundBlock( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, ) return srBlock @@ -77,9 +81,9 @@ func defaultSubroundBlockWithoutErrorFromSubround(sr *spos.Subround) bls.Subroun func initSubroundBlock( blockChain data.ChainHandler, - container *mock.ConsensusCoreMock, + container *spos.ConsensusCore, appStatusHandler core.AppStatusHandler, -) bls.SubroundBlock { +) v1.SubroundBlock { if blockChain == nil { blockChain = &testscommon.ChainHandlerStub{ GetCurrentBlockHeaderCalled: func() data.HeaderHandler { @@ -98,7 +102,7 @@ func initSubroundBlock( } } - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) 
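// --- illustrative note (not part of the patch) ---------------------------------
// Most of the remaining churn in these files swaps direct field access on the
// consensus state (sr.Header, sr.Body, sr.Data, sr.RoundCanceled, sr.RoundTimeStamp)
// for accessor methods (GetHeader/SetHeader, GetBody/SetBody, ...) behind the
// ConsensusStateHandler interface, whose presence is checked with check.IfNil
// rather than a plain == nil comparison. A minimal sketch of that refactor, using
// hypothetical stand-in types (not the real spos definitions):
//
//	// before: exported fields mutated directly by every subround
//	type consensusStateV0 struct {
//		Header        data.HeaderHandler
//		RoundCanceled bool
//	}
//
//	// after: state accessed through an interface, presumably so the v1 and v2
//	// subround implementations can share a single state object
//	type consensusStateHandler interface {
//		GetHeader() data.HeaderHandler
//		SetHeader(header data.HeaderHandler)
//		GetRoundCanceled() bool
//		SetRoundCanceled(canceled bool)
//		IsInterfaceNil() bool
//	}
//
//	func cancelRound(state consensusStateHandler) {
//		if check.IfNil(state) {
//			return
//		}
//		state.SetRoundCanceled(true)
//	}
// --------------------------------------------------------------------------------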
container.SetBlockchain(blockChain) @@ -108,19 +112,19 @@ func initSubroundBlock( return srBlock } -func createConsensusContainers() []*mock.ConsensusCoreMock { - consensusContainers := make([]*mock.ConsensusCoreMock, 0) - container := mock.InitConsensusCore() +func createConsensusContainers() []*spos.ConsensusCore { + consensusContainers := make([]*spos.ConsensusCore, 0) + container := consensusMock.InitConsensusCore() consensusContainers = append(consensusContainers, container) - container = mock.InitConsensusCoreHeaderV2() + container = consensusMock.InitConsensusCoreHeaderV2() consensusContainers = append(consensusContainers, container) return consensusContainers } func initSubroundBlockWithBlockProcessor( bp *testscommon.BlockProcessorStub, - container *mock.ConsensusCoreMock, -) bls.SubroundBlock { + container *spos.ConsensusCore, +) v1.SubroundBlock { blockChain := &testscommon.ChainHandlerStub{ GetGenesisHeaderCalled: func() data.HeaderHandler { return &block.Header{ @@ -136,7 +140,7 @@ func initSubroundBlockWithBlockProcessor( container.SetBlockchain(blockChain) container.SetBlockProcessor(blockProcessorMock) - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -147,10 +151,10 @@ func initSubroundBlockWithBlockProcessor( func TestSubroundBlock_NewSubroundBlockNilSubroundShouldFail(t *testing.T) { t.Parallel() - srBlock, err := bls.NewSubroundBlock( + srBlock, err := v1.NewSubroundBlock( nil, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, ) assert.Nil(t, srBlock) assert.Equal(t, spos.ErrNilSubround, err) @@ -158,9 +162,9 @@ func TestSubroundBlock_NewSubroundBlockNilSubroundShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockNilBlockchainShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -174,9 +178,9 @@ func TestSubroundBlock_NewSubroundBlockNilBlockchainShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockNilBlockProcessorShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -190,12 +194,12 @@ func TestSubroundBlock_NewSubroundBlockNilBlockProcessorShouldFail(t *testing.T) func TestSubroundBlock_NewSubroundBlockNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMock.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) - sr.ConsensusState = nil + sr.ConsensusStateHandler = nil srBlock, err := defaultSubroundBlockFromSubround(sr) assert.Nil(t, srBlock) @@ -204,9 +208,9 @@ func TestSubroundBlock_NewSubroundBlockNilConsensusStateShouldFail(t *testing.T) func 
TestSubroundBlock_NewSubroundBlockNilHasherShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -219,9 +223,9 @@ func TestSubroundBlock_NewSubroundBlockNilHasherShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockNilMarshalizerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -234,9 +238,9 @@ func TestSubroundBlock_NewSubroundBlockNilMarshalizerShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -249,9 +253,9 @@ func TestSubroundBlock_NewSubroundBlockNilMultiSignerContainerShouldFail(t *test func TestSubroundBlock_NewSubroundBlockNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -264,9 +268,9 @@ func TestSubroundBlock_NewSubroundBlockNilRoundHandlerShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockNilShardCoordinatorShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -279,9 +283,9 @@ func TestSubroundBlock_NewSubroundBlockNilShardCoordinatorShouldFail(t *testing. 
func TestSubroundBlock_NewSubroundBlockNilSyncTimerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) @@ -294,9 +298,9 @@ func TestSubroundBlock_NewSubroundBlockNilSyncTimerShouldFail(t *testing.T) { func TestSubroundBlock_NewSubroundBlockShouldWork(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) srBlock, err := defaultSubroundBlockFromSubround(sr) @@ -306,12 +310,12 @@ func TestSubroundBlock_NewSubroundBlockShouldWork(t *testing.T) { func TestSubroundBlock_DoBlockJob(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) r := sr.DoBlockJob() assert.False(t, r) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + sr.SetSelfPubKey(sr.Leader()) _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrBlock, true) r = sr.DoBlockJob() assert.False(t, r) @@ -331,34 +335,34 @@ func TestSubroundBlock_DoBlockJob(t *testing.T) { r = sr.DoBlockJob() assert.False(t, r) - bpm = mock.InitBlockProcessorMock(container.Marshalizer()) + bpm = consensusMock.InitBlockProcessorMock(container.Marshalizer()) container.SetBlockProcessor(bpm) - bm := &mock.BroadcastMessengerMock{ + bm := &consensusMock.BroadcastMessengerMock{ BroadcastConsensusMessageCalled: func(message *consensus.Message) error { return nil }, } container.SetBroadcastMessenger(bm) - container.SetRoundHandler(&mock.RoundHandlerMock{ + container.SetRoundHandler(&consensusMock.RoundHandlerMock{ RoundIndex: 1, }) r = sr.DoBlockJob() assert.True(t, r) - assert.Equal(t, uint64(1), sr.Header.GetNonce()) + assert.Equal(t, uint64(1), sr.GetHeader().GetNonce()) } func TestSubroundBlock_ReceivedBlockBodyAndHeaderDataAlreadySet(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.Leader()), bls.MtBlockBodyAndHeader) - sr.Data = []byte("some data") + sr.SetData([]byte("some data")) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) } @@ -366,15 +370,15 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderDataAlreadySet(t *testing.T) { func TestSubroundBlock_ReceivedBlockBodyAndHeaderNodeNotLeaderInCurrentRound(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} 
blkBody := &block.Body{} cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[1]), bls.MtBlockBodyAndHeader) - sr.Data = nil + sr.SetData(nil) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) } @@ -382,16 +386,16 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderNodeNotLeaderInCurrentRound(t * func TestSubroundBlock_ReceivedBlockBodyAndHeaderCannotProcessJobDone(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.Leader()), bls.MtBlockBodyAndHeader) - sr.Data = nil - _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrBlock, true) + sr.SetData(nil) + _ = sr.SetJobDone(sr.Leader(), bls.SrBlock, true) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) @@ -400,22 +404,22 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderCannotProcessJobDone(t *testing func TestSubroundBlock_ReceivedBlockBodyAndHeaderErrorDecoding(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - blProc := mock.InitBlockProcessorMock(container.Marshalizer()) + container := consensusMock.InitConsensusCore() + blProc := consensusMock.InitBlockProcessorMock(container.Marshalizer()) blProc.DecodeBlockHeaderCalled = func(dta []byte) data.HeaderHandler { // error decoding so return nil return nil } container.SetBlockProcessor(blProc) - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.Leader()), bls.MtBlockBodyAndHeader) - sr.Data = nil + sr.SetData(nil) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) @@ -424,16 +428,16 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderErrorDecoding(t *testing.T) { func TestSubroundBlock_ReceivedBlockBodyAndHeaderBodyAlreadyReceived(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.Leader()), bls.MtBlockBodyAndHeader) - sr.Data = nil - sr.Body = &block.Body{} + sr.SetData(nil) + sr.SetBody(&block.Body{}) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) @@ -442,16 +446,16 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderBodyAlreadyReceived(t *testing. 
func TestSubroundBlock_ReceivedBlockBodyAndHeaderHeaderAlreadyReceived(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{Nonce: 1} blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.Leader()), bls.MtBlockBodyAndHeader) - sr.Data = nil - sr.Header = &block.Header{Nonce: 1} + sr.SetData(nil) + sr.SetHeader(&block.Header{Nonce: 1}) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) } @@ -459,14 +463,16 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderHeaderAlreadyReceived(t *testin func TestSubroundBlock_ReceivedBlockBodyAndHeaderOK(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) t.Run("block is valid", func(t *testing.T) { hdr := createDefaultHeader() blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) - sr.Data = nil + leader, err := sr.GetLeader() + require.Nil(t, err) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(leader), bls.MtBlockBodyAndHeader) + sr.SetData(nil) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.True(t, r) }) @@ -475,15 +481,17 @@ func TestSubroundBlock_ReceivedBlockBodyAndHeaderOK(t *testing.T) { Nonce: 1, } blkBody := &block.Body{} - cnsMsg := createConsensusMessage(hdr, blkBody, []byte(sr.ConsensusGroup()[0]), bls.MtBlockBodyAndHeader) - sr.Data = nil + leader, err := sr.GetLeader() + require.Nil(t, err) + cnsMsg := createConsensusMessage(hdr, blkBody, []byte(leader), bls.MtBlockBodyAndHeader) + sr.SetData(nil) r := sr.ReceivedBlockBodyAndHeader(cnsMsg) assert.False(t, r) }) } func createConsensusMessage(header *block.Header, body *block.Body, leader []byte, topic consensus.MessageType) *consensus.Message { - marshaller := &mock.MarshalizerMock{} + marshaller := &marshallerMock.MarshalizerMock{} hasher := &hashingMocks.HasherMock{} hdrStr, _ := marshaller.Marshal(header) @@ -510,17 +518,19 @@ func createConsensusMessage(header *block.Header, body *block.Body, leader []byt func TestSubroundBlock_ReceivedBlock(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) - blockProcessorMock := mock.InitBlockProcessorMock(container.Marshalizer()) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + blockProcessorMock := consensusMock.InitBlockProcessorMock(container.Marshalizer()) blkBody := &block.Body{} - blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + blkBodyStr, _ := marshallerMock.MarshalizerMock{}.Marshal(blkBody) + leader, err := sr.GetLeader() + assert.Nil(t, err) cnsMsg := consensus.NewConsensusMessage( nil, nil, blkBodyStr, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBody), 0, @@ -531,11 +541,11 @@ func TestSubroundBlock_ReceivedBlock(t *testing.T) { currentPid, nil, ) - sr.Body = &block.Body{} + sr.SetBody(&block.Body{}) r := 
sr.ReceivedBlockBody(cnsMsg) assert.False(t, r) - sr.Body = nil + sr.SetBody(nil) cnsMsg.PubKey = []byte(sr.ConsensusGroup()[1]) r = sr.ReceivedBlockBody(cnsMsg) assert.False(t, r) @@ -558,7 +568,7 @@ func TestSubroundBlock_ReceivedBlock(t *testing.T) { nil, nil, hdrStr, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockHeader), 0, @@ -572,12 +582,12 @@ func TestSubroundBlock_ReceivedBlock(t *testing.T) { r = sr.ReceivedBlockHeader(cnsMsg) assert.False(t, r) - sr.Data = nil - sr.Header = hdr + sr.SetData(nil) + sr.SetHeader(hdr) r = sr.ReceivedBlockHeader(cnsMsg) assert.False(t, r) - sr.Header = nil + sr.SetHeader(nil) cnsMsg.PubKey = []byte(sr.ConsensusGroup()[1]) r = sr.ReceivedBlockHeader(cnsMsg) assert.False(t, r) @@ -589,11 +599,11 @@ func TestSubroundBlock_ReceivedBlock(t *testing.T) { sr.SetStatus(bls.SrBlock, spos.SsNotFinished) container.SetBlockProcessor(blockProcessorMock) - sr.Data = nil - sr.Header = nil + sr.SetData(nil) + sr.SetHeader(nil) hdr = createDefaultHeader() hdr.Nonce = 1 - hdrStr, _ = mock.MarshalizerMock{}.Marshal(hdr) + hdrStr, _ = marshallerMock.MarshalizerMock{}.Marshal(hdr) hdrHash = (&hashingMocks.HasherMock{}).Compute(string(hdrStr)) cnsMsg.BlockHeaderHash = hdrHash cnsMsg.Header = hdrStr @@ -603,14 +613,15 @@ func TestSubroundBlock_ReceivedBlock(t *testing.T) { func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenBodyAndHeaderAreNotSet(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + leader, _ := sr.GetLeader() cnsMsg := consensus.NewConsensusMessage( nil, nil, nil, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBodyAndHeader), 0, @@ -626,9 +637,9 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenBodyAndHeaderAre func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockFails(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) - blProcMock := mock.InitBlockProcessorMock(container.Marshalizer()) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + blProcMock := consensusMock.InitBlockProcessorMock(container.Marshalizer()) err := errors.New("error process block") blProcMock.ProcessBlockCalled = func(data.HeaderHandler, data.BodyHandler, func() time.Duration) error { return err @@ -636,13 +647,14 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockFail container.SetBlockProcessor(blProcMock) hdr := &block.Header{} blkBody := &block.Body{} - blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + blkBodyStr, _ := marshallerMock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() cnsMsg := consensus.NewConsensusMessage( nil, nil, blkBodyStr, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBody), 0, @@ -653,24 +665,25 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockFail currentPid, nil, ) - sr.Header = hdr - sr.Body = blkBody + sr.SetHeader(hdr) + sr.SetBody(blkBody) assert.False(t, sr.ProcessReceivedBlock(cnsMsg)) } func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockReturnsInNextRound(t *testing.T) { 
t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr := &block.Header{} blkBody := &block.Body{} - blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + blkBodyStr, _ := marshallerMock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() cnsMsg := consensus.NewConsensusMessage( nil, nil, blkBodyStr, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBody), 0, @@ -681,14 +694,14 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockRetu currentPid, nil, ) - sr.Header = hdr - sr.Body = blkBody - blockProcessorMock := mock.InitBlockProcessorMock(container.Marshalizer()) + sr.SetHeader(hdr) + sr.SetBody(blkBody) + blockProcessorMock := consensusMock.InitBlockProcessorMock(container.Marshalizer()) blockProcessorMock.ProcessBlockCalled = func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { return errors.New("error") } container.SetBlockProcessor(blockProcessorMock) - container.SetRoundHandler(&mock.RoundHandlerMock{RoundIndex: 1}) + container.SetRoundHandler(&consensusMock.RoundHandlerMock{RoundIndex: 1}) assert.False(t, sr.ProcessReceivedBlock(cnsMsg)) } @@ -697,17 +710,18 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnTrue(t *testing.T) { consensusContainers := createConsensusContainers() for _, container := range consensusContainers { - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) hdr, _ := container.BlockProcessor().CreateNewHeader(1, 1) hdr, blkBody, _ := container.BlockProcessor().CreateBlock(hdr, func() bool { return true }) - blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + blkBodyStr, _ := marshallerMock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() cnsMsg := consensus.NewConsensusMessage( nil, nil, blkBodyStr, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBody), 0, @@ -718,19 +732,19 @@ func TestSubroundBlock_ProcessReceivedBlockShouldReturnTrue(t *testing.T) { currentPid, nil, ) - sr.Header = hdr - sr.Body = blkBody + sr.SetHeader(hdr) + sr.SetBody(blkBody) assert.True(t, sr.ProcessReceivedBlock(cnsMsg)) } } func TestSubroundBlock_RemainingTimeShouldReturnNegativeValue(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() roundHandlerMock := initRoundHandlerMock() container.SetRoundHandler(roundHandlerMock) - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) remainingTimeInThisRound := func() time.Duration { roundStartTime := sr.RoundHandler().TimeStamp() currentTime := sr.SyncTimer().CurrentTime() @@ -739,19 +753,19 @@ func TestSubroundBlock_RemainingTimeShouldReturnNegativeValue(t *testing.T) { return remainingTime } - container.SetSyncTimer(&mock.SyncTimerMock{CurrentTimeCalled: func() time.Time { + container.SetSyncTimer(&consensusMock.SyncTimerMock{CurrentTimeCalled: func() time.Time { return time.Unix(0, 0).Add(roundTimeDuration * 84 / 100) }}) ret := remainingTimeInThisRound() assert.True(t, ret > 0) - container.SetSyncTimer(&mock.SyncTimerMock{CurrentTimeCalled: func() time.Time { + 
container.SetSyncTimer(&consensusMock.SyncTimerMock{CurrentTimeCalled: func() time.Time { return time.Unix(0, 0).Add(roundTimeDuration * 85 / 100) }}) ret = remainingTimeInThisRound() assert.True(t, ret == 0) - container.SetSyncTimer(&mock.SyncTimerMock{CurrentTimeCalled: func() time.Time { + container.SetSyncTimer(&consensusMock.SyncTimerMock{CurrentTimeCalled: func() time.Time { return time.Unix(0, 0).Add(roundTimeDuration * 86 / 100) }}) ret = remainingTimeInThisRound() @@ -760,24 +774,24 @@ func TestSubroundBlock_RemainingTimeShouldReturnNegativeValue(t *testing.T) { func TestSubroundBlock_DoBlockConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) - sr.RoundCanceled = true + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr.SetRoundCanceled(true) assert.False(t, sr.DoBlockConsensusCheck()) } func TestSubroundBlock_DoBlockConsensusCheckShouldReturnTrueWhenSubroundIsFinished(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) sr.SetStatus(bls.SrBlock, spos.SsFinished) assert.True(t, sr.DoBlockConsensusCheck()) } func TestSubroundBlock_DoBlockConsensusCheckShouldReturnTrueWhenBlockIsReceivedReturnTrue(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) for i := 0; i < sr.Threshold(bls.SrBlock); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrBlock, true) } @@ -786,15 +800,15 @@ func TestSubroundBlock_DoBlockConsensusCheckShouldReturnTrueWhenBlockIsReceivedR func TestSubroundBlock_DoBlockConsensusCheckShouldReturnFalseWhenBlockIsReceivedReturnFalse(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) assert.False(t, sr.DoBlockConsensusCheck()) } func TestSubroundBlock_IsBlockReceived(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) for i := 0; i < len(sr.ConsensusGroup()); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrBlock, false) _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, false) @@ -815,8 +829,8 @@ func TestSubroundBlock_IsBlockReceived(t *testing.T) { func TestSubroundBlock_HaveTimeInCurrentSubroundShouldReturnTrue(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) haveTimeInCurrentSubound := func() bool { roundStartTime := sr.RoundHandler().TimeStamp() currentTime := sr.SyncTimer().CurrentTime() @@ 
-825,14 +839,14 @@ func TestSubroundBlock_HaveTimeInCurrentSubroundShouldReturnTrue(t *testing.T) { return time.Duration(remainingTime) > 0 } - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMock.RoundHandlerMock{} roundHandlerMock.TimeDurationCalled = func() time.Duration { return 4000 * time.Millisecond } roundHandlerMock.TimeStampCalled = func() time.Time { return time.Unix(0, 0) } - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMock.SyncTimerMock{} timeElapsed := sr.EndTime() - 1 syncTimerMock.CurrentTimeCalled = func() time.Time { return time.Unix(0, timeElapsed) @@ -845,8 +859,8 @@ func TestSubroundBlock_HaveTimeInCurrentSubroundShouldReturnTrue(t *testing.T) { func TestSubroundBlock_HaveTimeInCurrentSuboundShouldReturnFalse(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMock.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) haveTimeInCurrentSubound := func() bool { roundStartTime := sr.RoundHandler().TimeStamp() currentTime := sr.SyncTimer().CurrentTime() @@ -855,14 +869,14 @@ func TestSubroundBlock_HaveTimeInCurrentSuboundShouldReturnFalse(t *testing.T) { return time.Duration(remainingTime) > 0 } - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMock.RoundHandlerMock{} roundHandlerMock.TimeDurationCalled = func() time.Duration { return 4000 * time.Millisecond } roundHandlerMock.TimeStampCalled = func() time.Time { return time.Unix(0, 0) } - syncTimerMock := &mock.SyncTimerMock{} + syncTimerMock := &consensusMock.SyncTimerMock{} timeElapsed := sr.EndTime() + 1 syncTimerMock.CurrentTimeCalled = func() time.Time { return time.Unix(0, timeElapsed) @@ -892,7 +906,7 @@ func TestSubroundBlock_CreateHeaderNilCurrentHeader(t *testing.T) { consensusContainers := createConsensusContainers() for _, container := range consensusContainers { - sr := *initSubroundBlock(blockChain, container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundBlock(blockChain, container, &statusHandler.AppStatusHandlerStub{}) _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(nil, nil) header, _ := sr.CreateHeader() header, body, _ := sr.CreateBlock(header) @@ -923,7 +937,7 @@ func TestSubroundBlock_CreateHeaderNilCurrentHeader(t *testing.T) { func TestSubroundBlock_CreateHeaderNotNilCurrentHeader(t *testing.T) { consensusContainers := createConsensusContainers() for _, container := range consensusContainers { - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(&block.Header{ Nonce: 1, }, []byte("root hash")) @@ -967,8 +981,8 @@ func TestSubroundBlock_CreateHeaderMultipleMiniBlocks(t *testing.T) { } }, } - container := mock.InitConsensusCore() - bp := mock.InitBlockProcessorMock(container.Marshalizer()) + container := consensusMock.InitConsensusCore() + bp := consensusMock.InitBlockProcessorMock(container.Marshalizer()) bp.CreateBlockCalled = func(header data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { shardHeader, _ := header.(*block.Header) shardHeader.MiniBlockHeaders = mbHeaders @@ -976,7 +990,7 @@ func TestSubroundBlock_CreateHeaderMultipleMiniBlocks(t *testing.T) { return shardHeader, &block.Body{}, nil } - sr := 
*initSubroundBlockWithBlockProcessor(bp, container) + sr := initSubroundBlockWithBlockProcessor(bp, container) container.SetBlockchain(&blockChainMock) header, _ := sr.CreateHeader() @@ -1002,12 +1016,12 @@ func TestSubroundBlock_CreateHeaderMultipleMiniBlocks(t *testing.T) { func TestSubroundBlock_CreateHeaderNilMiniBlocks(t *testing.T) { expectedErr := errors.New("nil mini blocks") - container := mock.InitConsensusCore() - bp := mock.InitBlockProcessorMock(container.Marshalizer()) + container := consensusMock.InitConsensusCore() + bp := consensusMock.InitBlockProcessorMock(container.Marshalizer()) bp.CreateBlockCalled = func(header data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { return nil, nil, expectedErr } - sr := *initSubroundBlockWithBlockProcessor(bp, container) + sr := initSubroundBlockWithBlockProcessor(bp, container) _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(&block.Header{ Nonce: 1, }, []byte("root hash")) @@ -1059,7 +1073,7 @@ func TestSubroundBlock_ReceivedBlockComputeProcessDuration(t *testing.T) { srDuration := srEndTime - srStartTime delay := srDuration * 430 / 1000 - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() receivedValue := uint64(0) container.SetBlockProcessor(&testscommon.BlockProcessorStub{ ProcessBlockCalled: func(_ data.HeaderHandler, _ data.BodyHandler, _ func() time.Duration) error { @@ -1067,20 +1081,22 @@ func TestSubroundBlock_ReceivedBlockComputeProcessDuration(t *testing.T) { return nil }, }) - sr := *initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{ + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{ SetUInt64ValueHandler: func(key string, value uint64) { receivedValue = value }}) hdr := &block.Header{} blkBody := &block.Body{} - blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + blkBodyStr, _ := marshallerMock.MarshalizerMock{}.Marshal(blkBody) + leader, err := sr.GetLeader() + assert.Nil(t, err) cnsMsg := consensus.NewConsensusMessage( nil, nil, blkBodyStr, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtBlockBody), 0, @@ -1091,8 +1107,8 @@ func TestSubroundBlock_ReceivedBlockComputeProcessDuration(t *testing.T) { currentPid, nil, ) - sr.Header = hdr - sr.Body = blkBody + sr.SetHeader(hdr) + sr.SetBody(blkBody) minimumExpectedValue := uint64(delay * 100 / srDuration) _ = sr.ProcessReceivedBlock(cnsMsg) @@ -1113,13 +1129,13 @@ func TestSubroundBlock_ReceivedBlockComputeProcessDurationWithZeroDurationShould } }() - container := mock.InitConsensusCore() + container := consensusMock.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) - srBlock := *defaultSubroundBlockWithoutErrorFromSubround(sr) + srBlock := defaultSubroundBlockWithoutErrorFromSubround(sr) srBlock.ComputeSubroundProcessingMetric(time.Now(), "dummy") } diff --git a/consensus/spos/bls/subroundEndRound.go b/consensus/spos/bls/v1/subroundEndRound.go similarity index 90% rename from consensus/spos/bls/subroundEndRound.go rename to consensus/spos/bls/v1/subroundEndRound.go index 21675715f39..c591c736aca 100644 --- a/consensus/spos/bls/subroundEndRound.go +++ b/consensus/spos/bls/v1/subroundEndRound.go @@ -1,4 +1,4 @@ -package bls +package v1 import ( "bytes" @@ -11,9 +11,11 @@ import ( 
"github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/display" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process/headerCheck" ) @@ -73,7 +75,7 @@ func checkNewSubroundEndRoundParams( if baseSubround == nil { return spos.ErrNilSubround } - if baseSubround.ConsensusState == nil { + if check.IfNil(baseSubround.ConsensusStateHandler) { return spos.ErrNilConsensusState } @@ -131,11 +133,11 @@ func (sr *subroundEndRound) receivedBlockHeaderFinalInfo(_ context.Context, cnsD } func (sr *subroundEndRound) isBlockHeaderFinalInfoValid(cnsDta *consensus.Message) bool { - if check.IfNil(sr.Header) { + if check.IfNil(sr.GetHeader()) { return false } - header := sr.Header.ShallowClone() + header := sr.GetHeader().ShallowClone() err := header.SetPubKeysBitmap(cnsDta.PubKeysBitmap) if err != nil { log.Debug("isBlockHeaderFinalInfoValid.SetPubKeysBitmap", "error", err.Error()) @@ -293,14 +295,15 @@ func (sr *subroundEndRound) doEndRoundJob(_ context.Context) bool { } func (sr *subroundEndRound) doEndRoundJobByLeader() bool { - bitmap := sr.GenerateBitmap(SrSignature) + bitmap := sr.GenerateBitmap(bls.SrSignature) err := sr.checkSignaturesValidity(bitmap) if err != nil { log.Debug("doEndRoundJobByLeader.checkSignaturesValidity", "error", err.Error()) return false } - if check.IfNil(sr.Header) { + header := sr.GetHeader() + if check.IfNil(header) { log.Error("doEndRoundJobByLeader.CheckNilHeader", "error", spos.ErrNilHeader) return false } @@ -312,13 +315,13 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { return false } - err = sr.Header.SetPubKeysBitmap(bitmap) + err = header.SetPubKeysBitmap(bitmap) if err != nil { log.Debug("doEndRoundJobByLeader.SetPubKeysBitmap", "error", err.Error()) return false } - err = sr.Header.SetSignature(sig) + err = header.SetSignature(sig) if err != nil { log.Debug("doEndRoundJobByLeader.SetSignature", "error", err.Error()) return false @@ -331,7 +334,7 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { return false } - err = sr.Header.SetLeaderSignature(leaderSignature) + err = header.SetLeaderSignature(leaderSignature) if err != nil { log.Debug("doEndRoundJobByLeader.SetLeaderSignature", "error", err.Error()) return false @@ -362,13 +365,13 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { } // broadcast header - err = sr.BroadcastMessenger().BroadcastHeader(sr.Header, []byte(leader)) + err = sr.BroadcastMessenger().BroadcastHeader(header, []byte(leader)) if err != nil { log.Debug("doEndRoundJobByLeader.BroadcastHeader", "error", err.Error()) } startTime := time.Now() - err = sr.BlockProcessor().CommitBlock(sr.Header, sr.Body) + err = sr.BlockProcessor().CommitBlock(header, sr.GetBody()) elapsedTime := time.Since(startTime) if elapsedTime >= common.CommitMaxTime { log.Warn("doEndRoundJobByLeader.CommitBlock", "elapsed time", elapsedTime) @@ -393,7 +396,7 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { log.Debug("doEndRoundJobByLeader.broadcastBlockDataLeader", "error", err.Error()) } - msg := fmt.Sprintf("Added proposed block with nonce %d in blockchain", 
sr.Header.GetNonce()) + msg := fmt.Sprintf("Added proposed block with nonce %d in blockchain", header.GetNonce()) log.Debug(display.Headline(msg, sr.SyncTimer().FormattedCurrentTime(), "+")) sr.updateMetricsForLeader() @@ -402,7 +405,8 @@ func (sr *subroundEndRound) doEndRoundJobByLeader() bool { } func (sr *subroundEndRound) aggregateSigsAndHandleInvalidSigners(bitmap []byte) ([]byte, []byte, error) { - sig, err := sr.SigningHandler().AggregateSigs(bitmap, sr.Header.GetEpoch()) + header := sr.GetHeader() + sig, err := sr.SigningHandler().AggregateSigs(bitmap, header.GetEpoch()) if err != nil { log.Debug("doEndRoundJobByLeader.AggregateSigs", "error", err.Error()) @@ -415,7 +419,7 @@ func (sr *subroundEndRound) aggregateSigsAndHandleInvalidSigners(bitmap []byte) return nil, nil, err } - err = sr.SigningHandler().Verify(sr.GetData(), bitmap, sr.Header.GetEpoch()) + err = sr.SigningHandler().Verify(sr.GetData(), bitmap, header.GetEpoch()) if err != nil { log.Debug("doEndRoundJobByLeader.Verify", "error", err.Error()) @@ -429,12 +433,13 @@ func (sr *subroundEndRound) verifyNodesOnAggSigFail() ([]string, error) { invalidPubKeys := make([]string, 0) pubKeys := sr.ConsensusGroup() - if check.IfNil(sr.Header) { + header := sr.GetHeader() + if check.IfNil(header) { return nil, spos.ErrNilHeader } for i, pk := range pubKeys { - isJobDone, err := sr.JobDone(pk, SrSignature) + isJobDone, err := sr.JobDone(pk, bls.SrSignature) if err != nil || !isJobDone { continue } @@ -445,11 +450,11 @@ func (sr *subroundEndRound) verifyNodesOnAggSigFail() ([]string, error) { } isSuccessfull := true - err = sr.SigningHandler().VerifySignatureShare(uint16(i), sigShare, sr.GetData(), sr.Header.GetEpoch()) + err = sr.SigningHandler().VerifySignatureShare(uint16(i), sigShare, sr.GetData(), header.GetEpoch()) if err != nil { isSuccessfull = false - err = sr.SetJobDone(pk, SrSignature, false) + err = sr.SetJobDone(pk, bls.SrSignature, false) if err != nil { return nil, err } @@ -520,9 +525,10 @@ func (sr *subroundEndRound) handleInvalidSignersOnAggSigFail() ([]byte, []byte, func (sr *subroundEndRound) computeAggSigOnValidNodes() ([]byte, []byte, error) { threshold := sr.Threshold(sr.Current()) - numValidSigShares := sr.ComputeSize(SrSignature) + numValidSigShares := sr.ComputeSize(bls.SrSignature) - if check.IfNil(sr.Header) { + header := sr.GetHeader() + if check.IfNil(header) { return nil, nil, spos.ErrNilHeader } @@ -531,13 +537,13 @@ func (sr *subroundEndRound) computeAggSigOnValidNodes() ([]byte, []byte, error) spos.ErrInvalidNumSigShares, numValidSigShares, threshold) } - bitmap := sr.GenerateBitmap(SrSignature) + bitmap := sr.GenerateBitmap(bls.SrSignature) err := sr.checkSignaturesValidity(bitmap) if err != nil { return nil, nil, err } - sig, err := sr.SigningHandler().AggregateSigs(bitmap, sr.Header.GetEpoch()) + sig, err := sr.SigningHandler().AggregateSigs(bitmap, header.GetEpoch()) if err != nil { return nil, nil, err } @@ -557,6 +563,7 @@ func (sr *subroundEndRound) createAndBroadcastHeaderFinalInfo() { return } + header := sr.GetHeader() cnsMsg := consensus.NewConsensusMessage( sr.GetData(), nil, @@ -564,12 +571,12 @@ func (sr *subroundEndRound) createAndBroadcastHeaderFinalInfo() { nil, []byte(leader), nil, - int(MtBlockHeaderFinalInfo), + int(bls.MtBlockHeaderFinalInfo), sr.RoundHandler().Index(), sr.ChainID(), - sr.Header.GetPubKeysBitmap(), - sr.Header.GetSignature(), - sr.Header.GetLeaderSignature(), + header.GetPubKeysBitmap(), + header.GetSignature(), + header.GetLeaderSignature(), 
sr.GetAssociatedPid([]byte(leader)), nil, ) @@ -581,9 +588,9 @@ func (sr *subroundEndRound) createAndBroadcastHeaderFinalInfo() { } log.Debug("step 3: block header final info has been sent", - "PubKeysBitmap", sr.Header.GetPubKeysBitmap(), - "AggregateSignature", sr.Header.GetSignature(), - "LeaderSignature", sr.Header.GetLeaderSignature()) + "PubKeysBitmap", header.GetPubKeysBitmap(), + "AggregateSignature", header.GetSignature(), + "LeaderSignature", header.GetLeaderSignature()) } func (sr *subroundEndRound) createAndBroadcastInvalidSigners(invalidSigners []byte) { @@ -605,7 +612,7 @@ func (sr *subroundEndRound) createAndBroadcastInvalidSigners(invalidSigners []by nil, []byte(leader), nil, - int(MtInvalidSigners), + int(bls.MtInvalidSigners), sr.RoundHandler().Index(), sr.ChainID(), nil, @@ -628,7 +635,7 @@ func (sr *subroundEndRound) doEndRoundJobByParticipant(cnsDta *consensus.Message sr.mutProcessingEndRound.Lock() defer sr.mutProcessingEndRound.Unlock() - if sr.RoundCanceled { + if sr.GetRoundCanceled() { return false } if !sr.IsConsensusDataSet() { @@ -652,13 +659,13 @@ func (sr *subroundEndRound) doEndRoundJobByParticipant(cnsDta *consensus.Message sr.SetProcessingBlock(true) - shouldNotCommitBlock := sr.ExtendedCalled || int64(header.GetRound()) < sr.RoundHandler().Index() + shouldNotCommitBlock := sr.GetExtendedCalled() || int64(header.GetRound()) < sr.RoundHandler().Index() if shouldNotCommitBlock { log.Debug("canceled round, extended has been called or round index has been changed", "round", sr.RoundHandler().Index(), "subround", sr.Name(), "header round", header.GetRound(), - "extended called", sr.ExtendedCalled, + "extended called", sr.GetExtendedCalled(), ) return false } @@ -673,7 +680,7 @@ func (sr *subroundEndRound) doEndRoundJobByParticipant(cnsDta *consensus.Message } startTime := time.Now() - err := sr.BlockProcessor().CommitBlock(header, sr.Body) + err := sr.BlockProcessor().CommitBlock(header, sr.GetBody()) elapsedTime := time.Since(startTime) if elapsedTime >= common.CommitMaxTime { log.Warn("doEndRoundJobByParticipant.CommitBlock", "elapsed time", elapsedTime) @@ -715,11 +722,11 @@ func (sr *subroundEndRound) haveConsensusHeaderWithFullInfo(cnsDta *consensus.Me return sr.isConsensusHeaderReceived() } - if check.IfNil(sr.Header) { + if check.IfNil(sr.GetHeader()) { return false, nil } - header := sr.Header.ShallowClone() + header := sr.GetHeader().ShallowClone() err := header.SetPubKeysBitmap(cnsDta.PubKeysBitmap) if err != nil { return false, nil @@ -739,11 +746,11 @@ func (sr *subroundEndRound) haveConsensusHeaderWithFullInfo(cnsDta *consensus.Me } func (sr *subroundEndRound) isConsensusHeaderReceived() (bool, data.HeaderHandler) { - if check.IfNil(sr.Header) { + if check.IfNil(sr.GetHeader()) { return false, nil } - consensusHeaderHash, err := core.CalculateHash(sr.Marshalizer(), sr.Hasher(), sr.Header) + consensusHeaderHash, err := core.CalculateHash(sr.Marshalizer(), sr.Hasher(), sr.GetHeader()) if err != nil { log.Debug("isConsensusHeaderReceived: calculate consensus header hash", "error", err.Error()) return false, nil @@ -787,7 +794,7 @@ func (sr *subroundEndRound) isConsensusHeaderReceived() (bool, data.HeaderHandle } func (sr *subroundEndRound) signBlockHeader() ([]byte, error) { - headerClone := sr.Header.ShallowClone() + headerClone := sr.GetHeader().ShallowClone() err := headerClone.SetLeaderSignature(nil) if err != nil { return nil, err @@ -813,7 +820,7 @@ func (sr *subroundEndRound) updateMetricsForLeader() { } func (sr *subroundEndRound) 
broadcastBlockDataLeader() error { - miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.Header, sr.Body) + miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.GetHeader(), sr.GetBody()) if err != nil { return err } @@ -824,7 +831,7 @@ func (sr *subroundEndRound) broadcastBlockDataLeader() error { return errGetLeader } - return sr.BroadcastMessenger().BroadcastBlockDataLeader(sr.Header, miniBlocks, transactions, []byte(leader)) + return sr.BroadcastMessenger().BroadcastBlockDataLeader(sr.GetHeader(), miniBlocks, transactions, []byte(leader)) } func (sr *subroundEndRound) setHeaderForValidator(header data.HeaderHandler) error { @@ -844,14 +851,14 @@ func (sr *subroundEndRound) prepareBroadcastBlockDataForValidator() error { return err } - go sr.BroadcastMessenger().PrepareBroadcastBlockDataValidator(sr.Header, miniBlocks, transactions, idx, pk) + go sr.BroadcastMessenger().PrepareBroadcastBlockDataValidator(sr.GetHeader(), miniBlocks, transactions, idx, pk) return nil } // doEndRoundConsensusCheck method checks if the consensus is achieved func (sr *subroundEndRound) doEndRoundConsensusCheck() bool { - if sr.RoundCanceled { + if sr.GetRoundCanceled() { return false } @@ -866,7 +873,7 @@ func (sr *subroundEndRound) checkSignaturesValidity(bitmap []byte) error { consensusGroup := sr.ConsensusGroup() signers := headerCheck.ComputeSignersPublicKeys(consensusGroup, bitmap) for _, pubKey := range signers { - isSigJobDone, err := sr.JobDone(pubKey, SrSignature) + isSigJobDone, err := sr.JobDone(pubKey, bls.SrSignature) if err != nil { return err } @@ -880,14 +887,14 @@ func (sr *subroundEndRound) checkSignaturesValidity(bitmap []byte) error { } func (sr *subroundEndRound) isOutOfTime() bool { - startTime := sr.RoundTimeStamp + startTime := sr.GetRoundTimeStamp() maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 if sr.RoundHandler().RemainingTime(startTime, maxTime) < 0 { log.Debug("canceled round, time is out", "round", sr.SyncTimer().FormattedCurrentTime(), sr.RoundHandler().Index(), "subround", sr.Name()) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return true } @@ -908,7 +915,7 @@ func (sr *subroundEndRound) getIndexPkAndDataToBroadcast() (int, []byte, map[uin return -1, nil, nil, nil, err } - miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.Header, sr.Body) + miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.GetHeader(), sr.GetBody()) if err != nil { return -1, nil, nil, nil, err } @@ -923,7 +930,7 @@ func (sr *subroundEndRound) getMinConsensusGroupIndexOfManagedKeys() int { minIdx := sr.ConsensusGroupSize() for idx, validator := range sr.ConsensusGroup() { - if !sr.IsKeyManagedByCurrentNode([]byte(validator)) { + if !sr.IsKeyManagedBySelf([]byte(validator)) { continue } diff --git a/consensus/spos/bls/subroundEndRound_test.go b/consensus/spos/bls/v1/subroundEndRound_test.go similarity index 77% rename from consensus/spos/bls/subroundEndRound_test.go rename to consensus/spos/bls/v1/subroundEndRound_test.go index 725513b8cb2..678f70b7d78 100644 --- a/consensus/spos/bls/subroundEndRound_test.go +++ b/consensus/spos/bls/v1/subroundEndRound_test.go @@ -1,4 +1,4 @@ -package bls_test +package v1_test import ( "bytes" @@ -12,27 +12,30 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" crypto 
"github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/p2p/factory" "github.com/multiversx/mx-chain-go/testscommon" consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func initSubroundEndRoundWithContainer( - container *mock.ConsensusCoreMock, + container *spos.ConsensusCore, appStatusHandler core.AppStatusHandler, -) bls.SubroundEndRound { +) v1.SubroundEndRound { ch := make(chan bool, 1) - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() sr, _ := spos.NewSubround( bls.SrSignature, bls.SrEndRound, @@ -49,10 +52,10 @@ func initSubroundEndRoundWithContainer( appStatusHandler, ) - srEndRound, _ := bls.NewSubroundEndRound( + srEndRound, _ := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, appStatusHandler, &testscommon.SentSignatureTrackerStub{}, @@ -61,16 +64,16 @@ func initSubroundEndRoundWithContainer( return srEndRound } -func initSubroundEndRound(appStatusHandler core.AppStatusHandler) bls.SubroundEndRound { - container := mock.InitConsensusCore() +func initSubroundEndRound(appStatusHandler core.AppStatusHandler) v1.SubroundEndRound { + container := consensusMocks.InitConsensusCore() return initSubroundEndRoundWithContainer(container, appStatusHandler) } func TestNewSubroundEndRound(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( bls.SrSignature, @@ -91,10 +94,10 @@ func TestNewSubroundEndRound(t *testing.T) { t.Run("nil subround should error", func(t *testing.T) { t.Parallel() - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( nil, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -106,10 +109,10 @@ func TestNewSubroundEndRound(t *testing.T) { t.Run("nil extend function handler should error", func(t *testing.T) { t.Parallel() - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, nil, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -121,10 +124,10 @@ func TestNewSubroundEndRound(t 
*testing.T) { t.Run("nil app status handler should error", func(t *testing.T) { t.Parallel() - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, nil, &testscommon.SentSignatureTrackerStub{}, @@ -136,25 +139,25 @@ func TestNewSubroundEndRound(t *testing.T) { t.Run("nil sent signatures tracker should error", func(t *testing.T) { t.Parallel() - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, nil, ) assert.Nil(t, srEndRound) - assert.Equal(t, bls.ErrNilSentSignatureTracker, err) + assert.Equal(t, v1.ErrNilSentSignatureTracker, err) }) } func TestSubroundEndRound_NewSubroundEndRoundNilBlockChainShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -173,10 +176,10 @@ func TestSubroundEndRound_NewSubroundEndRoundNilBlockChainShouldFail(t *testing. &statusHandler.AppStatusHandlerStub{}, ) container.SetBlockchain(nil) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -189,8 +192,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilBlockChainShouldFail(t *testing. func TestSubroundEndRound_NewSubroundEndRoundNilBlockProcessorShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -209,10 +212,10 @@ func TestSubroundEndRound_NewSubroundEndRoundNilBlockProcessorShouldFail(t *test &statusHandler.AppStatusHandlerStub{}, ) container.SetBlockProcessor(nil) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -225,8 +228,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilBlockProcessorShouldFail(t *test func TestSubroundEndRound_NewSubroundEndRoundNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -245,11 +248,11 @@ func TestSubroundEndRound_NewSubroundEndRoundNilConsensusStateShouldFail(t *test &statusHandler.AppStatusHandlerStub{}, ) - sr.ConsensusState = nil - srEndRound, err := bls.NewSubroundEndRound( + sr.ConsensusStateHandler = nil + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -262,8 +265,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilConsensusStateShouldFail(t *test func 
TestSubroundEndRound_NewSubroundEndRoundNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -282,10 +285,10 @@ func TestSubroundEndRound_NewSubroundEndRoundNilMultiSignerContainerShouldFail(t &statusHandler.AppStatusHandlerStub{}, ) container.SetMultiSignerContainer(nil) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -298,8 +301,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilMultiSignerContainerShouldFail(t func TestSubroundEndRound_NewSubroundEndRoundNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -318,10 +321,10 @@ func TestSubroundEndRound_NewSubroundEndRoundNilRoundHandlerShouldFail(t *testin &statusHandler.AppStatusHandlerStub{}, ) container.SetRoundHandler(nil) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -334,8 +337,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilRoundHandlerShouldFail(t *testin func TestSubroundEndRound_NewSubroundEndRoundNilSyncTimerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -354,10 +357,10 @@ func TestSubroundEndRound_NewSubroundEndRoundNilSyncTimerShouldFail(t *testing.T &statusHandler.AppStatusHandlerStub{}, ) container.SetSyncTimer(nil) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -370,8 +373,8 @@ func TestSubroundEndRound_NewSubroundEndRoundNilSyncTimerShouldFail(t *testing.T func TestSubroundEndRound_NewSubroundEndRoundShouldWork(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -390,10 +393,10 @@ func TestSubroundEndRound_NewSubroundEndRoundShouldWork(t *testing.T) { &statusHandler.AppStatusHandlerStub{}, ) - srEndRound, err := bls.NewSubroundEndRound( + srEndRound, err := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -405,8 +408,8 @@ func TestSubroundEndRound_NewSubroundEndRoundShouldWork(t *testing.T) { func TestSubroundEndRound_DoEndRoundJobErrAggregatingSigShouldFail(t *testing.T) { 
t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) signingHandler := &consensusMocks.SigningHandlerStub{ AggregateSigsCalled: func(bitmap []byte, epoch uint32) ([]byte, error) { @@ -415,9 +418,10 @@ func TestSubroundEndRound_DoEndRoundJobErrAggregatingSigShouldFail(t *testing.T) } container.SetSigningHandler(signingHandler) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") assert.True(t, sr.IsSelfLeaderInCurrentRound()) r := sr.DoEndRoundJob() @@ -427,11 +431,12 @@ func TestSubroundEndRound_DoEndRoundJobErrAggregatingSigShouldFail(t *testing.T) func TestSubroundEndRound_DoEndRoundJobErrCommitBlockShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - blProcMock := mock.InitBlockProcessorMock(container.Marshalizer()) + blProcMock := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) blProcMock.CommitBlockCalled = func( header data.HeaderHandler, body data.BodyHandler, @@ -440,7 +445,7 @@ func TestSubroundEndRound_DoEndRoundJobErrCommitBlockShouldFail(t *testing.T) { } container.SetBlockProcessor(blProcMock) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.False(t, r) @@ -449,19 +454,20 @@ func TestSubroundEndRound_DoEndRoundJobErrCommitBlockShouldFail(t *testing.T) { func TestSubroundEndRound_DoEndRoundJobErrTimeIsOutShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") remainingTime := time.Millisecond - roundHandlerMock := &mock.RoundHandlerMock{ + roundHandlerMock := &consensusMocks.RoundHandlerMock{ RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { return remainingTime }, } container.SetRoundHandler(roundHandlerMock) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -475,17 +481,18 @@ func TestSubroundEndRound_DoEndRoundJobErrTimeIsOutShouldFail(t *testing.T) { func TestSubroundEndRound_DoEndRoundJobErrBroadcastBlockOK(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - bm := &mock.BroadcastMessengerMock{ + container := consensusMocks.InitConsensusCore() + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return errors.New("error") }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -495,16 +502,16 @@ func 
TestSubroundEndRound_DoEndRoundJobErrMarshalizedDataToBroadcastOK(t *testin t.Parallel() err := errors.New("") - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - bpm := mock.InitBlockProcessorMock(container.Marshalizer()) + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) bpm.MarshalizedDataToBroadcastCalled = func(header data.HeaderHandler, body data.BodyHandler) (map[uint32][]byte, map[string][][]byte, error) { err = errors.New("error marshalized data to broadcast") return make(map[uint32][]byte), make(map[string][][]byte), err } container.SetBlockProcessor(bpm) - bm := &mock.BroadcastMessengerMock{ + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return nil }, @@ -516,10 +523,11 @@ func TestSubroundEndRound_DoEndRoundJobErrMarshalizedDataToBroadcastOK(t *testin }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -530,15 +538,15 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastMiniBlocksOK(t *testing.T) { t.Parallel() err := errors.New("") - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - bpm := mock.InitBlockProcessorMock(container.Marshalizer()) + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) bpm.MarshalizedDataToBroadcastCalled = func(header data.HeaderHandler, body data.BodyHandler) (map[uint32][]byte, map[string][][]byte, error) { return make(map[uint32][]byte), make(map[string][][]byte), nil } container.SetBlockProcessor(bpm) - bm := &mock.BroadcastMessengerMock{ + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return nil }, @@ -551,10 +559,11 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastMiniBlocksOK(t *testing.T) { }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -566,15 +575,15 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastTransactionsOK(t *testing.T) t.Parallel() err := errors.New("") - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - bpm := mock.InitBlockProcessorMock(container.Marshalizer()) + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) bpm.MarshalizedDataToBroadcastCalled = func(header data.HeaderHandler, body data.BodyHandler) (map[uint32][]byte, map[string][][]byte, error) { return make(map[uint32][]byte), make(map[string][][]byte), nil } container.SetBlockProcessor(bpm) - bm := &mock.BroadcastMessengerMock{ + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return nil }, @@ -587,10 +596,11 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastTransactionsOK(t *testing.T) }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, 
&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -601,17 +611,18 @@ func TestSubroundEndRound_DoEndRoundJobErrBroadcastTransactionsOK(t *testing.T) func TestSubroundEndRound_DoEndRoundJobAllOK(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - bm := &mock.BroadcastMessengerMock{ + container := consensusMocks.InitConsensusCore() + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return errors.New("error") }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJob() assert.True(t, r) @@ -621,7 +632,7 @@ func TestSubroundEndRound_CheckIfSignatureIsFilled(t *testing.T) { t.Parallel() expectedSignature := []byte("signature") - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() signingHandler := &consensusMocks.SigningHandlerStub{ CreateSignatureForPublicKeyCalled: func(publicKeyBytes []byte, msg []byte) ([]byte, error) { var receivedHdr block.Header @@ -630,27 +641,28 @@ func TestSubroundEndRound_CheckIfSignatureIsFilled(t *testing.T) { }, } container.SetSigningHandler(signingHandler) - bm := &mock.BroadcastMessengerMock{ + bm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { return errors.New("error") }, } container.SetBroadcastMessenger(bm) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") - sr.Header = &block.Header{Nonce: 5} + sr.SetHeader(&block.Header{Nonce: 5}) r := sr.DoEndRoundJob() assert.True(t, r) - assert.Equal(t, expectedSignature, sr.Header.GetLeaderSignature()) + assert.Equal(t, expectedSignature, sr.GetHeader().GetLeaderSignature()) } func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.RoundCanceled = true + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetRoundCanceled(true) ok := sr.DoEndRoundConsensusCheck() assert.False(t, ok) @@ -659,7 +671,7 @@ func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsCa func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnTrueWhenRoundIsFinished(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) sr.SetStatus(bls.SrEndRound, spos.SsFinished) ok := sr.DoEndRoundConsensusCheck() @@ -669,7 +681,7 @@ func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnTrueWhenRoundIsFin func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsNotFinished(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) 
ok := sr.DoEndRoundConsensusCheck() assert.False(t, ok) @@ -678,7 +690,7 @@ func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsNo func TestSubroundEndRound_CheckSignaturesValidityShouldErrNilSignature(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) err := sr.CheckSignaturesValidity([]byte{2}) assert.Equal(t, spos.ErrNilSignature, err) @@ -687,7 +699,7 @@ func TestSubroundEndRound_CheckSignaturesValidityShouldErrNilSignature(t *testin func TestSubroundEndRound_CheckSignaturesValidityShouldReturnNil(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) @@ -698,8 +710,8 @@ func TestSubroundEndRound_CheckSignaturesValidityShouldReturnNil(t *testing.T) { func TestSubroundEndRound_DoEndRoundJobByParticipant_RoundCanceledShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.RoundCanceled = true + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetRoundCanceled(true) cnsData := consensus.Message{} res := sr.DoEndRoundJobByParticipant(&cnsData) @@ -709,8 +721,8 @@ func TestSubroundEndRound_DoEndRoundJobByParticipant_RoundCanceledShouldReturnFa func TestSubroundEndRound_DoEndRoundJobByParticipant_ConsensusDataNotSetShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Data = nil + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetData(nil) cnsData := consensus.Message{} res := sr.DoEndRoundJobByParticipant(&cnsData) @@ -720,7 +732,7 @@ func TestSubroundEndRound_DoEndRoundJobByParticipant_ConsensusDataNotSetShouldRe func TestSubroundEndRound_DoEndRoundJobByParticipant_PreviousSubroundNotFinishedShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) sr.SetStatus(2, spos.SsNotFinished) cnsData := consensus.Message{} res := sr.DoEndRoundJobByParticipant(&cnsData) @@ -730,7 +742,7 @@ func TestSubroundEndRound_DoEndRoundJobByParticipant_PreviousSubroundNotFinished func TestSubroundEndRound_DoEndRoundJobByParticipant_CurrentSubroundFinishedShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) // set previous as finished sr.SetStatus(2, spos.SsFinished) @@ -746,7 +758,7 @@ func TestSubroundEndRound_DoEndRoundJobByParticipant_CurrentSubroundFinishedShou func TestSubroundEndRound_DoEndRoundJobByParticipant_ConsensusHeaderNotReceivedShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) // set previous as finished sr.SetStatus(2, spos.SsFinished) @@ -763,8 +775,8 @@ func TestSubroundEndRound_DoEndRoundJobByParticipant_ShouldReturnTrue(t *testing t.Parallel() hdr := &block.Header{Nonce: 37} - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Header = hdr + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) sr.AddReceivedHeader(hdr) // set previous as finished @@ -782,8 +794,8 @@ 
func TestSubroundEndRound_IsConsensusHeaderReceived_NoReceivedHeadersShouldRetur t.Parallel() hdr := &block.Header{Nonce: 37} - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Header = hdr + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) res, retHdr := sr.IsConsensusHeaderReceived() assert.False(t, res) @@ -795,9 +807,9 @@ func TestSubroundEndRound_IsConsensusHeaderReceived_HeaderNotReceivedShouldRetur hdr := &block.Header{Nonce: 37} hdrToSearchFor := &block.Header{Nonce: 38} - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) sr.AddReceivedHeader(hdr) - sr.Header = hdrToSearchFor + sr.SetHeader(hdrToSearchFor) res, retHdr := sr.IsConsensusHeaderReceived() assert.False(t, res) @@ -808,8 +820,8 @@ func TestSubroundEndRound_IsConsensusHeaderReceivedShouldReturnTrue(t *testing.T t.Parallel() hdr := &block.Header{Nonce: 37} - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Header = hdr + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) sr.AddReceivedHeader(hdr) res, retHdr := sr.IsConsensusHeaderReceived() @@ -820,7 +832,7 @@ func TestSubroundEndRound_IsConsensusHeaderReceivedShouldReturnTrue(t *testing.T func TestSubroundEndRound_HaveConsensusHeaderWithFullInfoNilHdrShouldNotWork(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{} @@ -843,8 +855,8 @@ func TestSubroundEndRound_HaveConsensusHeaderWithFullInfoShouldWork(t *testing.T Signature: originalSig, LeaderSignature: originalLeaderSig, } - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Header = &hdr + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&hdr) cnsData := consensus.Message{ PubKeysBitmap: newPubKeyBitMap, @@ -864,8 +876,8 @@ func TestSubroundEndRound_CreateAndBroadcastHeaderFinalInfoBroadcastShouldBeCall chanRcv := make(chan bool, 1) leaderSigInHdr := []byte("leader sig") - container := mock.InitConsensusCore() - messenger := &mock.BroadcastMessengerMock{ + container := consensusMocks.InitConsensusCore() + messenger := &consensusMocks.BroadcastMessengerMock{ BroadcastConsensusMessageCalled: func(message *consensus.Message) error { chanRcv <- true assert.Equal(t, message.LeaderSignature, leaderSigInHdr) @@ -873,8 +885,8 @@ func TestSubroundEndRound_CreateAndBroadcastHeaderFinalInfoBroadcastShouldBeCall }, } container.SetBroadcastMessenger(messenger) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - sr.Header = &block.Header{LeaderSignature: leaderSigInHdr} + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{LeaderSignature: leaderSigInHdr}) sr.CreateAndBroadcastHeaderFinalInfo() @@ -889,8 +901,8 @@ func TestSubroundEndRound_ReceivedBlockHeaderFinalInfoShouldWork(t *testing.T) { t.Parallel() hdr := &block.Header{Nonce: 37} - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) - sr.Header = hdr + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) sr.AddReceivedHeader(hdr) sr.SetStatus(2, spos.SsFinished) @@ -909,9 +921,9 @@ func TestSubroundEndRound_ReceivedBlockHeaderFinalInfoShouldWork(t *testing.T) { func 
TestSubroundEndRound_ReceivedBlockHeaderFinalInfoShouldReturnFalseWhenFinalInfoIsNotValid(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - headerSigVerifier := &mock.HeaderSigVerifierStub{ + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{ VerifyLeaderSignatureCalled: func(header data.HeaderHandler) error { return errors.New("error") }, @@ -921,12 +933,12 @@ func TestSubroundEndRound_ReceivedBlockHeaderFinalInfoShouldReturnFalseWhenFinal } container.SetHeaderSigVerifier(headerSigVerifier) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), PubKey: []byte("A"), } - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) res := sr.ReceivedBlockHeaderFinalInfo(&cnsData) assert.False(t, res) } @@ -934,7 +946,7 @@ func TestSubroundEndRound_ReceivedBlockHeaderFinalInfoShouldReturnFalseWhenFinal func TestSubroundEndRound_IsOutOfTimeShouldReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) res := sr.IsOutOfTime() assert.False(t, res) @@ -944,8 +956,8 @@ func TestSubroundEndRound_IsOutOfTimeShouldReturnTrue(t *testing.T) { t.Parallel() // update roundHandler's mock, so it will calculate for real the duration - container := mock.InitConsensusCore() - roundHandler := mock.RoundHandlerMock{RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { + container := consensusMocks.InitConsensusCore() + roundHandler := consensusMocks.RoundHandlerMock{RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { currentTime := time.Now() elapsedTime := currentTime.Sub(startTime) remainingTime := maxTime - elapsedTime @@ -953,9 +965,9 @@ func TestSubroundEndRound_IsOutOfTimeShouldReturnTrue(t *testing.T) { return remainingTime }} container.SetRoundHandler(&roundHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - sr.RoundTimeStamp = time.Now().AddDate(0, 0, -1) + sr.SetRoundTimeStamp(time.Now().AddDate(0, 0, -1)) res := sr.IsOutOfTime() assert.True(t, res) @@ -964,9 +976,9 @@ func TestSubroundEndRound_IsOutOfTimeShouldReturnTrue(t *testing.T) { func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerifyLeaderSignatureFails(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - headerSigVerifier := &mock.HeaderSigVerifierStub{ + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{ VerifyLeaderSignatureCalled: func(header data.HeaderHandler) error { return errors.New("error") }, @@ -976,9 +988,9 @@ func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerify } container.SetHeaderSigVerifier(headerSigVerifier) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsDta := &consensus.Message{} - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) isValid := sr.IsBlockHeaderFinalInfoValid(cnsDta) assert.False(t, isValid) } @@ -986,9 +998,9 @@ func 
TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerify func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerifySignatureFails(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - headerSigVerifier := &mock.HeaderSigVerifierStub{ + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{ VerifyLeaderSignatureCalled: func(header data.HeaderHandler) error { return nil }, @@ -998,9 +1010,9 @@ func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerify } container.SetHeaderSigVerifier(headerSigVerifier) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsDta := &consensus.Message{} - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) isValid := sr.IsBlockHeaderFinalInfoValid(cnsDta) assert.False(t, isValid) } @@ -1008,9 +1020,9 @@ func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnFalseWhenVerify func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnTrue(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - headerSigVerifier := &mock.HeaderSigVerifierStub{ + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{ VerifyLeaderSignatureCalled: func(header data.HeaderHandler) error { return nil }, @@ -1020,9 +1032,9 @@ func TestSubroundEndRound_IsBlockHeaderFinalInfoValidShouldReturnTrue(t *testing } container.SetHeaderSigVerifier(headerSigVerifier) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsDta := &consensus.Message{} - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) isValid := sr.IsBlockHeaderFinalInfoValid(cnsDta) assert.True(t, isValid) } @@ -1033,8 +1045,8 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { t.Run("fail to get signature share", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -1045,7 +1057,7 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { container.SetSigningHandler(signingHandler) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) _, err := sr.VerifyNodesOnAggSigFail() @@ -1055,8 +1067,8 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { t.Run("fail to verify signature share, job done will be set to false", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -1068,7 +1080,7 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { }, } - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) _ = 
sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) container.SetSigningHandler(signingHandler) @@ -1083,8 +1095,8 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) signingHandler := &consensusMocks.SigningHandlerStub{ SignatureShareCalled: func(index uint16) ([]byte, error) { return nil, nil @@ -1098,7 +1110,7 @@ func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { } container.SetSigningHandler(signingHandler) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) _ = sr.SetJobDone(sr.ConsensusGroup()[1], bls.SrSignature, true) @@ -1114,9 +1126,9 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { t.Run("invalid number of valid sig shares", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - sr.Header = &block.Header{} + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{}) sr.SetThreshold(bls.SrEndRound, 2) _, _, err := sr.ComputeAggSigOnValidNodes() @@ -1126,8 +1138,8 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { t.Run("fail to created aggregated sig", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -1137,7 +1149,7 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { } container.SetSigningHandler(signingHandler) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) _, _, err := sr.ComputeAggSigOnValidNodes() @@ -1147,8 +1159,8 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { t.Run("fail to set aggregated sig", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) expectedErr := errors.New("exptected error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -1157,7 +1169,7 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { }, } container.SetSigningHandler(signingHandler) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) _, _, err := sr.ComputeAggSigOnValidNodes() @@ -1167,9 +1179,9 @@ func TestComputeAddSigOnValidNodes(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - sr.Header = &block.Header{} + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, 
&statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{}) _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) bitmap, sig, err := sr.ComputeAggSigOnValidNodes() @@ -1185,8 +1197,8 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { t.Run("not enough valid signature shares", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) verifySigShareNumCalls := 0 verifyFirstCall := true @@ -1220,7 +1232,7 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) _ = sr.SetJobDone(sr.ConsensusGroup()[1], bls.SrSignature, true) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJobByLeader() require.False(t, r) @@ -1232,8 +1244,8 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) verifySigShareNumCalls := 0 verifyFirstCall := true @@ -1268,7 +1280,7 @@ func TestSubroundEndRound_DoEndRoundJobByLeaderVerificationFail(t *testing.T) { _ = sr.SetJobDone(sr.ConsensusGroup()[1], bls.SrSignature, true) _ = sr.SetJobDone(sr.ConsensusGroup()[2], bls.SrSignature, true) - sr.Header = &block.Header{} + sr.SetHeader(&block.Header{}) r := sr.DoEndRoundJobByLeader() require.True(t, r) @@ -1284,10 +1296,10 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("consensus data is not set", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) - sr.ConsensusState.Data = nil + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.ConsensusStateHandler.SetData(nil) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), @@ -1301,9 +1313,9 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("received message node is not leader in current round", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), @@ -1317,10 +1329,11 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("received message from self leader should return false", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), @@ -1334,14 +1347,14 @@ func 
TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("received message from self multikey leader should return false", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{ IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { return string(pkBytes) == "A" }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) sr, _ := spos.NewSubround( bls.SrSignature, bls.SrEndRound, @@ -1358,10 +1371,10 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { &statusHandler.AppStatusHandlerStub{}, ) - srEndRound, _ := bls.NewSubroundEndRound( + srEndRound, _ := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, @@ -1381,9 +1394,9 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("received hash does not match the hash from current consensus state", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("Y"), @@ -1397,9 +1410,9 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("process received message verification failed, different round index", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), @@ -1414,9 +1427,9 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("empty invalid signers", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), PubKey: []byte("A"), @@ -1437,10 +1450,10 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { }, } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), PubKey: []byte("A"), @@ -1454,9 +1467,9 @@ func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) 
cnsData := consensus.Message{ BlockHeaderHash: []byte("X"), @@ -1475,7 +1488,7 @@ func TestVerifyInvalidSigners(t *testing.T) { t.Run("failed to deserialize invalidSigners field, should error", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() expectedErr := errors.New("expected err") messageSigningHandler := &mock.MessageSigningHandlerStub{ @@ -1486,7 +1499,7 @@ func TestVerifyInvalidSigners(t *testing.T) { container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) err := sr.VerifyInvalidSigners([]byte{}) require.Equal(t, expectedErr, err) @@ -1495,7 +1508,7 @@ func TestVerifyInvalidSigners(t *testing.T) { t.Run("failed to verify low level p2p message, should error", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() invalidSigners := []p2p.MessageP2P{&factory.Message{ FromField: []byte("from"), @@ -1515,7 +1528,7 @@ func TestVerifyInvalidSigners(t *testing.T) { container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) err := sr.VerifyInvalidSigners(invalidSignersBytes) require.Equal(t, expectedErr, err) @@ -1524,7 +1537,7 @@ func TestVerifyInvalidSigners(t *testing.T) { t.Run("failed to verify signature share", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() pubKey := []byte("A") // it's in consensus @@ -1557,7 +1570,7 @@ func TestVerifyInvalidSigners(t *testing.T) { container.SetSigningHandler(signingHandler) container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) err := sr.VerifyInvalidSigners(invalidSignersBytes) require.Nil(t, err) @@ -1567,7 +1580,7 @@ func TestVerifyInvalidSigners(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() pubKey := []byte("A") // it's in consensus @@ -1585,7 +1598,7 @@ func TestVerifyInvalidSigners(t *testing.T) { messageSigningHandler := &mock.MessageSignerMock{} container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) err := sr.VerifyInvalidSigners(invalidSignersBytes) require.Nil(t, err) @@ -1600,7 +1613,7 @@ func TestSubroundEndRound_CreateAndBroadcastInvalidSigners(t *testing.T) { expectedInvalidSigners := []byte("invalid signers") - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() nodeRedundancy := &mock.NodeRedundancyHandlerStub{ IsRedundancyNodeCalled: func() bool { return true @@ -1610,14 +1623,14 @@ func TestSubroundEndRound_CreateAndBroadcastInvalidSigners(t *testing.T) { }, } container.SetNodeRedundancyHandler(nodeRedundancy) - messenger := &mock.BroadcastMessengerMock{ + messenger := &consensusMocks.BroadcastMessengerMock{ 
BroadcastConsensusMessageCalled: func(message *consensus.Message) error { assert.Fail(t, "should have not been called") return nil }, } container.SetBroadcastMessenger(messenger) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.CreateAndBroadcastInvalidSigners(expectedInvalidSigners) }) @@ -1630,8 +1643,8 @@ func TestSubroundEndRound_CreateAndBroadcastInvalidSigners(t *testing.T) { expectedInvalidSigners := []byte("invalid signers") wasCalled := false - container := mock.InitConsensusCore() - messenger := &mock.BroadcastMessengerMock{ + container := consensusMocks.InitConsensusCore() + messenger := &consensusMocks.BroadcastMessengerMock{ BroadcastConsensusMessageCalled: func(message *consensus.Message) error { assert.Equal(t, expectedInvalidSigners, message.InvalidSigners) wasCalled = true @@ -1640,8 +1653,9 @@ func TestSubroundEndRound_CreateAndBroadcastInvalidSigners(t *testing.T) { }, } container.SetBroadcastMessenger(messenger) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.SetSelfPubKey("A") + sr.SetLeader("A") sr.CreateAndBroadcastInvalidSigners(expectedInvalidSigners) @@ -1657,7 +1671,7 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { t.Run("empty p2p messages slice if not in state", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() messageSigningHandler := &mock.MessageSigningHandlerStub{ SerializeCalled: func(messages []p2p.MessageP2P) ([]byte, error) { @@ -1669,7 +1683,7 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) invalidSigners := []string{"B", "C"} invalidSignersBytes, err := sr.GetFullMessagesForInvalidSigners(invalidSigners) @@ -1680,7 +1694,7 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { t.Run("should work", func(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() expectedInvalidSigners := []byte("expectedInvalidSigners") @@ -1694,7 +1708,7 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { container.SetMessageSigningHandler(messageSigningHandler) - sr := *initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) sr.AddMessageWithSignature("B", &p2pmocks.P2PMessageMock{}) sr.AddMessageWithSignature("C", &p2pmocks.P2PMessageMock{}) @@ -1709,10 +1723,10 @@ func TestGetFullMessagesForInvalidSigners(t *testing.T) { func TestSubroundEndRound_getMinConsensusGroupIndexOfManagedKeys(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{} ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) sr, _ := spos.NewSubround( bls.SrSignature, bls.SrEndRound, @@ -1729,10 +1743,10 @@ func TestSubroundEndRound_getMinConsensusGroupIndexOfManagedKeys(t 
*testing.T) { &statusHandler.AppStatusHandlerStub{}, ) - srEndRound, _ := bls.NewSubroundEndRound( + srEndRound, _ := v1.NewSubroundEndRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, &statusHandler.AppStatusHandlerStub{}, &testscommon.SentSignatureTrackerStub{}, diff --git a/consensus/spos/bls/subroundSignature.go b/consensus/spos/bls/v1/subroundSignature.go similarity index 94% rename from consensus/spos/bls/subroundSignature.go rename to consensus/spos/bls/v1/subroundSignature.go index ac06cc72fdd..1d71ac59420 100644 --- a/consensus/spos/bls/subroundSignature.go +++ b/consensus/spos/bls/v1/subroundSignature.go @@ -1,4 +1,4 @@ -package bls +package v1 import ( "context" @@ -8,9 +8,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" ) type subroundSignature struct { @@ -60,7 +62,7 @@ func checkNewSubroundSignatureParams( if baseSubround == nil { return spos.ErrNilSubround } - if baseSubround.ConsensusState == nil { + if check.IfNil(baseSubround.ConsensusStateHandler) { return spos.ErrNilConsensusState } @@ -74,7 +76,7 @@ func (sr *subroundSignature) doSignatureJob(_ context.Context) bool { if !sr.CanDoSubroundJob(sr.Current()) { return false } - if check.IfNil(sr.Header) { + if check.IfNil(sr.GetHeader()) { log.Error("doSignatureJob", "error", spos.ErrNilHeader) return false } @@ -92,7 +94,7 @@ func (sr *subroundSignature) doSignatureJob(_ context.Context) bool { signatureShare, err := sr.SigningHandler().CreateSignatureShareForPublicKey( sr.GetData(), uint16(selfIndex), - sr.Header.GetEpoch(), + sr.GetHeader().GetEpoch(), []byte(sr.SelfPubKey()), ) if err != nil { @@ -125,7 +127,7 @@ func (sr *subroundSignature) createAndSendSignatureMessage(signatureShare []byte nil, pkBytes, nil, - int(MtSignature), + int(bls.MtSignature), sr.RoundHandler().Index(), sr.ChainID(), nil, @@ -236,7 +238,7 @@ func (sr *subroundSignature) receivedSignature(_ context.Context, cnsDta *consen // doSignatureConsensusCheck method checks if the consensus in the subround Signature is achieved func (sr *subroundSignature) doSignatureConsensusCheck() bool { - if sr.RoundCanceled { + if sr.GetRoundCanceled() { return false } @@ -250,7 +252,7 @@ func (sr *subroundSignature) doSignatureConsensusCheck() bool { isSelfInConsensusGroup := sr.IsNodeInConsensusGroup(sr.SelfPubKey()) || sr.IsMultiKeyInConsensusGroup() threshold := sr.Threshold(sr.Current()) - if sr.FallbackHeaderValidator().ShouldApplyFallbackValidation(sr.Header) { + if sr.FallbackHeaderValidator().ShouldApplyFallbackValidation(sr.GetHeader()) { threshold = sr.FallbackThreshold(sr.Current()) log.Warn("subroundSignature.doSignatureConsensusCheck: fallback validation has been applied", "minimum number of signatures required", threshold, @@ -261,7 +263,7 @@ func (sr *subroundSignature) doSignatureConsensusCheck() bool { areSignaturesCollected, numSigs := sr.areSignaturesCollected(threshold) areAllSignaturesCollected := numSigs == sr.ConsensusGroupSize() - isJobDoneByLeader := isSelfLeader && (areAllSignaturesCollected || (areSignaturesCollected && sr.WaitingAllSignaturesTimeOut)) + isJobDoneByLeader := isSelfLeader && (areAllSignaturesCollected || 
(areSignaturesCollected && sr.GetWaitingAllSignaturesTimeOut())) selfJobDone := true if sr.IsNodeInConsensusGroup(sr.SelfPubKey()) { @@ -332,7 +334,7 @@ func (sr *subroundSignature) waitAllSignatures() { return } - sr.WaitingAllSignaturesTimeOut = true + sr.SetWaitingAllSignaturesTimeOut(true) select { case sr.ConsensusChannel() <- true: @@ -352,12 +354,12 @@ func (sr *subroundSignature) doSignatureJobForManagedKeys() bool { isMultiKeyLeader := sr.IsMultiKeyLeaderInCurrentRound() numMultiKeysSignaturesSent := 0 - for idx, pk := range sr.ConsensusGroup() { + for _, pk := range sr.ConsensusGroup() { pkBytes := []byte(pk) if sr.IsJobDone(pk, sr.Current()) { continue } - if !sr.IsKeyManagedByCurrentNode(pkBytes) { + if !sr.IsKeyManagedBySelf(pkBytes) { continue } @@ -370,7 +372,7 @@ func (sr *subroundSignature) doSignatureJobForManagedKeys() bool { signatureShare, err := sr.SigningHandler().CreateSignatureShareForPublicKey( sr.GetData(), uint16(selfIndex), - sr.Header.GetEpoch(), + sr.GetHeader().GetEpoch(), pkBytes, ) if err != nil { @@ -387,8 +389,13 @@ func (sr *subroundSignature) doSignatureJobForManagedKeys() bool { numMultiKeysSignaturesSent++ } sr.sentSignatureTracker.SignatureSent(pkBytes) + leader, err := sr.GetLeader() + if err != nil { + log.Debug("doSignatureJobForManagedKeys.GetLeader", "error", err.Error()) + return false + } - isLeader := idx == spos.IndexOfLeaderInConsensusGroup + isLeader := pk == leader ok := sr.completeSignatureSubRound(pk, isLeader) if !ok { return false diff --git a/consensus/spos/bls/subroundSignature_test.go b/consensus/spos/bls/v1/subroundSignature_test.go similarity index 78% rename from consensus/spos/bls/subroundSignature_test.go rename to consensus/spos/bls/v1/subroundSignature_test.go index 9ee8a03ba19..a3708f8c326 100644 --- a/consensus/spos/bls/subroundSignature_test.go +++ b/consensus/spos/bls/v1/subroundSignature_test.go @@ -1,4 +1,4 @@ -package bls_test +package v1_test import ( "testing" @@ -6,19 +6,21 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" "github.com/multiversx/mx-chain-go/testscommon" consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" ) -func initSubroundSignatureWithContainer(container *mock.ConsensusCoreMock) bls.SubroundSignature { - consensusState := initConsensusState() +func initSubroundSignatureWithContainer(container *spos.ConsensusCore) v1.SubroundSignature { + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -37,7 +39,7 @@ func initSubroundSignatureWithContainer(container *mock.ConsensusCoreMock) bls.S 
&statusHandler.AppStatusHandlerStub{}, ) - srSignature, _ := bls.NewSubroundSignature( + srSignature, _ := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -47,16 +49,16 @@ func initSubroundSignatureWithContainer(container *mock.ConsensusCoreMock) bls.S return srSignature } -func initSubroundSignature() bls.SubroundSignature { - container := mock.InitConsensusCore() +func initSubroundSignature() v1.SubroundSignature { + container := consensusMocks.InitConsensusCore() return initSubroundSignatureWithContainer(container) } func TestNewSubroundSignature(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -78,7 +80,7 @@ func TestNewSubroundSignature(t *testing.T) { t.Run("nil subround should error", func(t *testing.T) { t.Parallel() - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( nil, extend, &statusHandler.AppStatusHandlerStub{}, @@ -91,7 +93,7 @@ func TestNewSubroundSignature(t *testing.T) { t.Run("nil extend function handler should error", func(t *testing.T) { t.Parallel() - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, nil, &statusHandler.AppStatusHandlerStub{}, @@ -104,7 +106,7 @@ func TestNewSubroundSignature(t *testing.T) { t.Run("nil app status handler should error", func(t *testing.T) { t.Parallel() - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, nil, @@ -117,7 +119,7 @@ func TestNewSubroundSignature(t *testing.T) { t.Run("nil sent signatures tracker should error", func(t *testing.T) { t.Parallel() - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -125,15 +127,15 @@ func TestNewSubroundSignature(t *testing.T) { ) assert.Nil(t, srSignature) - assert.Equal(t, bls.ErrNilSentSignatureTracker, err) + assert.Equal(t, v1.ErrNilSentSignatureTracker, err) }) } func TestSubroundSignature_NewSubroundSignatureNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -152,8 +154,8 @@ func TestSubroundSignature_NewSubroundSignatureNilConsensusStateShouldFail(t *te &statusHandler.AppStatusHandlerStub{}, ) - sr.ConsensusState = nil - srSignature, err := bls.NewSubroundSignature( + sr.ConsensusStateHandler = nil + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -167,8 +169,8 @@ func TestSubroundSignature_NewSubroundSignatureNilConsensusStateShouldFail(t *te func TestSubroundSignature_NewSubroundSignatureNilHasherShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -187,7 +189,7 @@ func TestSubroundSignature_NewSubroundSignatureNilHasherShouldFail(t *testing.T) &statusHandler.AppStatusHandlerStub{}, ) container.SetHasher(nil) - srSignature, err := bls.NewSubroundSignature( + srSignature, err := 
v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -201,8 +203,8 @@ func TestSubroundSignature_NewSubroundSignatureNilHasherShouldFail(t *testing.T) func TestSubroundSignature_NewSubroundSignatureNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -221,7 +223,7 @@ func TestSubroundSignature_NewSubroundSignatureNilMultiSignerContainerShouldFail &statusHandler.AppStatusHandlerStub{}, ) container.SetMultiSignerContainer(nil) - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -235,8 +237,8 @@ func TestSubroundSignature_NewSubroundSignatureNilMultiSignerContainerShouldFail func TestSubroundSignature_NewSubroundSignatureNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -256,7 +258,7 @@ func TestSubroundSignature_NewSubroundSignatureNilRoundHandlerShouldFail(t *test ) container.SetRoundHandler(nil) - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -270,8 +272,8 @@ func TestSubroundSignature_NewSubroundSignatureNilRoundHandlerShouldFail(t *test func TestSubroundSignature_NewSubroundSignatureNilSyncTimerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -290,7 +292,7 @@ func TestSubroundSignature_NewSubroundSignatureNilSyncTimerShouldFail(t *testing &statusHandler.AppStatusHandlerStub{}, ) container.SetSyncTimer(nil) - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -304,8 +306,8 @@ func TestSubroundSignature_NewSubroundSignatureNilSyncTimerShouldFail(t *testing func TestSubroundSignature_NewSubroundSignatureShouldWork(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := spos.NewSubround( @@ -324,7 +326,7 @@ func TestSubroundSignature_NewSubroundSignatureShouldWork(t *testing.T) { &statusHandler.AppStatusHandlerStub{}, ) - srSignature, err := bls.NewSubroundSignature( + srSignature, err := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -338,15 +340,15 @@ func TestSubroundSignature_NewSubroundSignatureShouldWork(t *testing.T) { func TestSubroundSignature_DoSignatureJob(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundSignatureWithContainer(container) + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) - sr.Header = &block.Header{} - sr.Data = nil + sr.SetHeader(&block.Header{}) + sr.SetData(nil) r := sr.DoSignatureJob() assert.False(t, r) - sr.Data = 
[]byte("X") + sr.SetData([]byte("X")) err := errors.New("create signature share error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -370,18 +372,21 @@ func TestSubroundSignature_DoSignatureJob(t *testing.T) { assert.True(t, r) _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrSignature, false) - sr.RoundCanceled = false - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + sr.SetRoundCanceled(false) + leader, err := sr.GetLeader() + assert.Nil(t, err) + + sr.SetSelfPubKey(leader) r = sr.DoSignatureJob() assert.True(t, r) - assert.False(t, sr.RoundCanceled) + assert.False(t, sr.GetRoundCanceled()) } func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusStateWithKeysHandler( + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusStateWithKeysHandler( &testscommon.KeysHandlerStub{ IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { return true @@ -407,7 +412,7 @@ func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { ) signatureSentForPks := make(map[string]struct{}) - srSignature, _ := bls.NewSubroundSignature( + srSignature, _ := v1.NewSubroundSignature( sr, extend, &statusHandler.AppStatusHandlerStub{}, @@ -418,12 +423,12 @@ func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { }, ) - srSignature.Header = &block.Header{} - srSignature.Data = nil + srSignature.SetHeader(&block.Header{}) + srSignature.SetData(nil) r := srSignature.DoSignatureJob() assert.False(t, r) - sr.Data = []byte("X") + sr.SetData([]byte("X")) err := errors.New("create signature share error") signingHandler := &consensusMocks.SigningHandlerStub{ @@ -447,11 +452,15 @@ func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { assert.True(t, r) _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrSignature, false) - sr.RoundCanceled = false - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + sr.SetRoundCanceled(false) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + + sr.SetSelfPubKey(leader) r = srSignature.DoSignatureJob() assert.True(t, r) - assert.False(t, sr.RoundCanceled) + assert.False(t, sr.GetRoundCanceled()) expectedMap := map[string]struct{}{ "A": {}, "B": {}, @@ -469,10 +478,10 @@ func TestSubroundSignature_DoSignatureJobWithMultikey(t *testing.T) { func TestSubroundSignature_ReceivedSignature(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() signature := []byte("signature") cnsMsg := consensus.NewConsensusMessage( - sr.Data, + sr.GetData(), signature, nil, nil, @@ -488,20 +497,22 @@ func TestSubroundSignature_ReceivedSignature(t *testing.T) { nil, ) - sr.Header = &block.Header{} - sr.Data = nil + sr.SetHeader(&block.Header{}) + sr.SetData(nil) r := sr.ReceivedSignature(cnsMsg) assert.False(t, r) - sr.Data = []byte("Y") + sr.SetData([]byte("Y")) r = sr.ReceivedSignature(cnsMsg) assert.False(t, r) - sr.Data = []byte("X") + sr.SetData([]byte("X")) r = sr.ReceivedSignature(cnsMsg) assert.False(t, r) + leader, err := sr.GetLeader() + assert.Nil(t, err) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + sr.SetSelfPubKey(leader) cnsMsg.PubKey = []byte("X") r = sr.ReceivedSignature(cnsMsg) @@ -538,14 +549,14 @@ func TestSubroundSignature_ReceivedSignatureStoreShareFailed(t *testing.T) { }, } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetSigningHandler(signingHandler) - sr := *initSubroundSignatureWithContainer(container) - sr.Header = 
&block.Header{} + sr := initSubroundSignatureWithContainer(container) + sr.SetHeader(&block.Header{}) signature := []byte("signature") cnsMsg := consensus.NewConsensusMessage( - sr.Data, + sr.GetData(), signature, nil, nil, @@ -561,19 +572,21 @@ func TestSubroundSignature_ReceivedSignatureStoreShareFailed(t *testing.T) { nil, ) - sr.Data = nil + sr.SetData(nil) r := sr.ReceivedSignature(cnsMsg) assert.False(t, r) - sr.Data = []byte("Y") + sr.SetData([]byte("Y")) r = sr.ReceivedSignature(cnsMsg) assert.False(t, r) - sr.Data = []byte("X") + sr.SetData([]byte("X")) r = sr.ReceivedSignature(cnsMsg) assert.False(t, r) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) cnsMsg.PubKey = []byte("X") r = sr.ReceivedSignature(cnsMsg) @@ -599,7 +612,7 @@ func TestSubroundSignature_ReceivedSignatureStoreShareFailed(t *testing.T) { func TestSubroundSignature_SignaturesCollected(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() for i := 0; i < len(sr.ConsensusGroup()); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrBlock, false) @@ -628,15 +641,15 @@ func TestSubroundSignature_SignaturesCollected(t *testing.T) { func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() - sr.RoundCanceled = true + sr := initSubroundSignature() + sr.SetRoundCanceled(true) assert.False(t, sr.DoSignatureConsensusCheck()) } func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenSubroundIsFinished(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() sr.SetStatus(bls.SrSignature, spos.SsFinished) assert.True(t, sr.DoSignatureConsensusCheck()) } @@ -644,7 +657,7 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenSubround func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenSignaturesCollectedReturnTrue(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() for i := 0; i < sr.Threshold(bls.SrSignature); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) @@ -656,18 +669,20 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenSignatur func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenSignaturesCollectedReturnFalse(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() assert.False(t, sr.DoSignatureConsensusCheck()) } func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenNotAllSignaturesCollectedAndTimeIsNotOut(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundSignatureWithContainer(container) - sr.WaitingAllSignaturesTimeOut = false + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + sr.SetWaitingAllSignaturesTimeOut(false) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) for i := 0; i < sr.Threshold(bls.SrSignature); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) @@ -679,11 +694,13 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenNotAllS func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenAllSignaturesCollected(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := 
*initSubroundSignatureWithContainer(container) - sr.WaitingAllSignaturesTimeOut = false + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + sr.SetWaitingAllSignaturesTimeOut(false) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) for i := 0; i < sr.ConsensusGroupSize(); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) @@ -695,11 +712,13 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenAllSigna func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenEnoughButNotAllSignaturesCollectedAndTimeIsOut(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - sr := *initSubroundSignatureWithContainer(container) - sr.WaitingAllSignaturesTimeOut = true + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + sr.SetWaitingAllSignaturesTimeOut(true) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) for i := 0; i < sr.Threshold(bls.SrSignature); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) @@ -711,14 +730,14 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenEnoughBu func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenFallbackThresholdCouldNotBeApplied(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetFallbackHeaderValidator(&testscommon.FallBackHeaderValidatorStub{ ShouldApplyFallbackValidationCalled: func(headerHandler data.HeaderHandler) bool { return false }, }) - sr := *initSubroundSignatureWithContainer(container) - sr.WaitingAllSignaturesTimeOut = false + sr := initSubroundSignatureWithContainer(container) + sr.SetWaitingAllSignaturesTimeOut(false) sr.SetSelfPubKey(sr.ConsensusGroup()[0]) @@ -732,16 +751,18 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnFalseWhenFallbac func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenFallbackThresholdCouldBeApplied(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetFallbackHeaderValidator(&testscommon.FallBackHeaderValidatorStub{ ShouldApplyFallbackValidationCalled: func(headerHandler data.HeaderHandler) bool { return true }, }) - sr := *initSubroundSignatureWithContainer(container) - sr.WaitingAllSignaturesTimeOut = true + sr := initSubroundSignatureWithContainer(container) + sr.SetWaitingAllSignaturesTimeOut(true) - sr.SetSelfPubKey(sr.ConsensusGroup()[0]) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) for i := 0; i < sr.FallbackThreshold(bls.SrSignature); i++ { _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) @@ -753,14 +774,16 @@ func TestSubroundSignature_DoSignatureConsensusCheckShouldReturnTrueWhenFallback func TestSubroundSignature_ReceivedSignatureReturnFalseWhenConsensusDataIsNotEqual(t *testing.T) { t.Parallel() - sr := *initSubroundSignature() + sr := initSubroundSignature() + leader, err := sr.GetLeader() + assert.Nil(t, err) cnsMsg := consensus.NewConsensusMessage( - append(sr.Data, []byte("X")...), + append(sr.GetData(), []byte("X")...), []byte("signature"), nil, nil, - []byte(sr.ConsensusGroup()[0]), + []byte(leader), []byte("sig"), int(bls.MtSignature), 0, diff --git 
a/consensus/spos/bls/subroundStartRound.go b/consensus/spos/bls/v1/subroundStartRound.go similarity index 92% rename from consensus/spos/bls/subroundStartRound.go rename to consensus/spos/bls/v1/subroundStartRound.go index 571270dd774..a47d9235cd2 100644 --- a/consensus/spos/bls/subroundStartRound.go +++ b/consensus/spos/bls/v1/subroundStartRound.go @@ -1,4 +1,4 @@ -package bls +package v1 import ( "context" @@ -11,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" outportcore "github.com/multiversx/mx-chain-core-go/data/outport" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/outport" @@ -80,7 +81,7 @@ func checkNewSubroundStartRoundParams( if baseSubround == nil { return spos.ErrNilSubround } - if baseSubround.ConsensusState == nil { + if check.IfNil(baseSubround.ConsensusStateHandler) { return spos.ErrNilConsensusState } @@ -105,8 +106,8 @@ func (sr *subroundStartRound) SetOutportHandler(outportHandler outport.OutportHa // doStartRoundJob method does the job of the subround StartRound func (sr *subroundStartRound) doStartRoundJob(_ context.Context) bool { sr.ResetConsensusState() - sr.RoundIndex = sr.RoundHandler().Index() - sr.RoundTimeStamp = sr.RoundHandler().TimeStamp() + sr.SetRoundIndex(sr.RoundHandler().Index()) + sr.SetRoundTimeStamp(sr.RoundHandler().TimeStamp()) topic := spos.GetConsensusTopicID(sr.ShardCoordinator()) sr.GetAntiFloodHandler().ResetForTopic(topic) sr.resetConsensusMessages() @@ -115,7 +116,7 @@ func (sr *subroundStartRound) doStartRoundJob(_ context.Context) bool { // doStartRoundConsensusCheck method checks if the consensus is achieved in the subround StartRound func (sr *subroundStartRound) doStartRoundConsensusCheck() bool { - if sr.RoundCanceled { + if sr.GetRoundCanceled() { return false } @@ -144,7 +145,7 @@ func (sr *subroundStartRound) initCurrentRound() bool { "round index", sr.RoundHandler().Index(), "error", err.Error()) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return false } @@ -163,13 +164,13 @@ func (sr *subroundStartRound) initCurrentRound() bool { if err != nil { log.Debug("initCurrentRound.GetLeader", "error", err.Error()) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return false } msg := "" - if sr.IsKeyManagedByCurrentNode([]byte(leader)) { + if sr.IsKeyManagedBySelf([]byte(leader)) { msg = " (my turn in multi-key)" } if leader == sr.SelfPubKey() && sr.ShouldConsiderSelfKeyInConsensus() { @@ -192,7 +193,7 @@ func (sr *subroundStartRound) initCurrentRound() bool { sr.indexRoundIfNeeded(pubKeys) isSingleKeyLeader := leader == sr.SelfPubKey() && sr.ShouldConsiderSelfKeyInConsensus() - isLeader := isSingleKeyLeader || sr.IsKeyManagedByCurrentNode([]byte(leader)) + isLeader := isSingleKeyLeader || sr.IsKeyManagedBySelf([]byte(leader)) isSelfInConsensus := sr.IsNodeInConsensusGroup(sr.SelfPubKey()) || numMultiKeysInConsensusGroup > 0 if !isSelfInConsensus { log.Debug("not in consensus group") @@ -208,19 +209,19 @@ func (sr *subroundStartRound) initCurrentRound() bool { if err != nil { log.Debug("initCurrentRound.Reset", "error", err.Error()) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return false } - startTime := sr.RoundTimeStamp + startTime := sr.GetRoundTimeStamp() maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 if 
sr.RoundHandler().RemainingTime(startTime, maxTime) < 0 { log.Debug("canceled round, time is out", "round", sr.SyncTimer().FormattedCurrentTime(), sr.RoundHandler().Index(), "subround", sr.Name()) - sr.RoundCanceled = true + sr.SetRoundCanceled(true) return false } @@ -237,7 +238,7 @@ func (sr *subroundStartRound) computeNumManagedKeysInConsensusGroup(pubKeys []st numMultiKeysInConsensusGroup := 0 for _, pk := range pubKeys { pkBytes := []byte(pk) - if sr.IsKeyManagedByCurrentNode(pkBytes) { + if sr.IsKeyManagedBySelf(pkBytes) { numMultiKeysInConsensusGroup++ log.Trace("in consensus group with multi key", "pk", core.GetTrimmedPk(hex.EncodeToString(pkBytes))) @@ -297,7 +298,7 @@ func (sr *subroundStartRound) indexRoundIfNeeded(pubKeys []string) { BlockWasProposed: false, ShardId: shardId, Epoch: epoch, - Timestamp: uint64(sr.RoundTimeStamp.Unix()), + Timestamp: uint64(sr.GetRoundTimeStamp().Unix()), } roundsInfo := &outportcore.RoundsInfo{ ShardID: shardId, @@ -322,9 +323,9 @@ func (sr *subroundStartRound) generateNextConsensusGroup(roundIndex int64) error shardId := sr.ShardCoordinator().SelfId() - nextConsensusGroup, err := sr.GetNextConsensusGroup( + leader, nextConsensusGroup, err := sr.GetNextConsensusGroup( randomSeed, - uint64(sr.RoundIndex), + uint64(sr.GetRoundIndex()), shardId, sr.NodesCoordinator(), currentHeader.GetEpoch(), @@ -341,6 +342,10 @@ func (sr *subroundStartRound) generateNextConsensusGroup(roundIndex int64) error } sr.SetConsensusGroup(nextConsensusGroup) + sr.SetLeader(leader) + + consensusGroupSizeForEpoch := sr.NodesCoordinator().ConsensusGroupSizeForShardAndEpoch(shardId, currentHeader.GetEpoch()) + sr.SetConsensusGroupSize(consensusGroupSizeForEpoch) return nil } @@ -369,5 +374,5 @@ func (sr *subroundStartRound) changeEpoch(currentEpoch uint32) { // NotifyOrder returns the notification order for a start of epoch event func (sr *subroundStartRound) NotifyOrder() uint32 { - return common.ConsensusOrder + return common.ConsensusStartRoundOrder } diff --git a/consensus/spos/bls/subroundStartRound_test.go b/consensus/spos/bls/v1/subroundStartRound_test.go similarity index 73% rename from consensus/spos/bls/subroundStartRound_test.go rename to consensus/spos/bls/v1/subroundStartRound_test.go index 2f5c21d2659..e5b898930f5 100644 --- a/consensus/spos/bls/subroundStartRound_test.go +++ b/consensus/spos/bls/v1/subroundStartRound_test.go @@ -1,26 +1,31 @@ -package bls_test +package v1_test import ( "errors" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v1 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v1" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) -func 
defaultSubroundStartRoundFromSubround(sr *spos.Subround) (bls.SubroundStartRound, error) { - startRound, err := bls.NewSubroundStartRound( +func defaultSubroundStartRoundFromSubround(sr *spos.Subround) (v1.SubroundStartRound, error) { + startRound, err := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -29,11 +34,11 @@ func defaultSubroundStartRoundFromSubround(sr *spos.Subround) (bls.SubroundStart return startRound, err } -func defaultWithoutErrorSubroundStartRoundFromSubround(sr *spos.Subround) bls.SubroundStartRound { - startRound, _ := bls.NewSubroundStartRound( +func defaultWithoutErrorSubroundStartRoundFromSubround(sr *spos.Subround) v1.SubroundStartRound { + startRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -65,14 +70,14 @@ func defaultSubround( ) } -func initSubroundStartRoundWithContainer(container spos.ConsensusCoreHandler) bls.SubroundStartRound { - consensusState := initConsensusState() +func initSubroundStartRoundWithContainer(container spos.ConsensusCoreHandler) v1.SubroundStartRound { + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -81,8 +86,8 @@ func initSubroundStartRoundWithContainer(container spos.ConsensusCoreHandler) bl return srStartRound } -func initSubroundStartRound() bls.SubroundStartRound { - container := mock.InitConsensusCore() +func initSubroundStartRound() v1.SubroundStartRound { + container := consensusMocks.InitConsensusCore() return initSubroundStartRoundWithContainer(container) } @@ -90,8 +95,8 @@ func TestNewSubroundStartRound(t *testing.T) { t.Parallel() ch := make(chan bool, 1) - consensusState := initConsensusState() - container := mock.InitConsensusCore() + consensusState := initializers.InitConsensusState() + container := consensusMocks.InitConsensusCore() sr, _ := spos.NewSubround( -1, bls.SrStartRound, @@ -111,10 +116,10 @@ func TestNewSubroundStartRound(t *testing.T) { t.Run("nil subround should error", func(t *testing.T) { t.Parallel() - srStartRound, err := bls.NewSubroundStartRound( + srStartRound, err := v1.NewSubroundStartRound( nil, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -126,10 +131,10 @@ func TestNewSubroundStartRound(t *testing.T) { t.Run("nil extend function handler should error", func(t *testing.T) { t.Parallel() - srStartRound, err := bls.NewSubroundStartRound( + srStartRound, err := v1.NewSubroundStartRound( sr, nil, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -142,10 +147,10 @@ func TestNewSubroundStartRound(t *testing.T) { t.Run("nil executeStoredMessages function handler should error", func(t *testing.T) { t.Parallel() - srStartRound, err := bls.NewSubroundStartRound( + srStartRound, err := v1.NewSubroundStartRound( sr, extend, - 
bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, nil, resetConsensusMessages, &testscommon.SentSignatureTrackerStub{}, @@ -158,10 +163,10 @@ func TestNewSubroundStartRound(t *testing.T) { t.Run("nil resetConsensusMessages function handler should error", func(t *testing.T) { t.Parallel() - srStartRound, err := bls.NewSubroundStartRound( + srStartRound, err := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, nil, &testscommon.SentSignatureTrackerStub{}, @@ -174,26 +179,26 @@ func TestNewSubroundStartRound(t *testing.T) { t.Run("nil sent signatures tracker should error", func(t *testing.T) { t.Parallel() - srStartRound, err := bls.NewSubroundStartRound( + srStartRound, err := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, executeStoredMessages, resetConsensusMessages, nil, ) assert.Nil(t, srStartRound) - assert.Equal(t, bls.ErrNilSentSignatureTracker, err) + assert.Equal(t, v1.ErrNilSentSignatureTracker, err) }) } func TestSubroundStartRound_NewSubroundStartRoundNilBlockChainShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -207,9 +212,9 @@ func TestSubroundStartRound_NewSubroundStartRoundNilBlockChainShouldFail(t *test func TestSubroundStartRound_NewSubroundStartRoundNilBootstrapperShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -223,13 +228,13 @@ func TestSubroundStartRound_NewSubroundStartRoundNilBootstrapperShouldFail(t *te func TestSubroundStartRound_NewSubroundStartRoundNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) - sr.ConsensusState = nil + sr.ConsensusStateHandler = nil srStartRound, err := defaultSubroundStartRoundFromSubround(sr) assert.Nil(t, srStartRound) @@ -239,9 +244,9 @@ func TestSubroundStartRound_NewSubroundStartRoundNilConsensusStateShouldFail(t * func TestSubroundStartRound_NewSubroundStartRoundNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -255,9 +260,9 @@ func TestSubroundStartRound_NewSubroundStartRoundNilMultiSignerContainerShouldFa func TestSubroundStartRound_NewSubroundStartRoundNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -271,9 +276,9 @@ func 
TestSubroundStartRound_NewSubroundStartRoundNilRoundHandlerShouldFail(t *te func TestSubroundStartRound_NewSubroundStartRoundNilSyncTimerShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -287,13 +292,13 @@ func TestSubroundStartRound_NewSubroundStartRoundNilSyncTimerShouldFail(t *testi func TestSubroundStartRound_NewSubroundStartRoundNilValidatorGroupSelectorShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) - container.SetValidatorGroupSelector(nil) + container.SetNodesCoordinator(nil) srStartRound, err := defaultSubroundStartRoundFromSubround(sr) assert.Nil(t, srStartRound) @@ -303,9 +308,9 @@ func TestSubroundStartRound_NewSubroundStartRoundNilValidatorGroupSelectorShould func TestSubroundStartRound_NewSubroundStartRoundShouldWork(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) @@ -319,14 +324,14 @@ func TestSubroundStartRound_NewSubroundStartRoundShouldWork(t *testing.T) { func TestSubroundStartRound_DoStartRoundShouldReturnTrue(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() - consensusState := initConsensusState() + consensusState := initializers.InitConsensusState() ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) - srStartRound := *defaultWithoutErrorSubroundStartRoundFromSubround(sr) + srStartRound := defaultWithoutErrorSubroundStartRoundFromSubround(sr) r := srStartRound.DoStartRoundJob() assert.True(t, r) @@ -335,9 +340,9 @@ func TestSubroundStartRound_DoStartRoundShouldReturnTrue(t *testing.T) { func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { t.Parallel() - sr := *initSubroundStartRound() + sr := initSubroundStartRound() - sr.RoundCanceled = true + sr.SetRoundCanceled(true) ok := sr.DoStartRoundConsensusCheck() assert.False(t, ok) @@ -346,7 +351,7 @@ func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenRound func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenRoundIsFinished(t *testing.T) { t.Parallel() - sr := *initSubroundStartRound() + sr := initSubroundStartRound() sr.SetStatus(bls.SrStartRound, spos.SsFinished) @@ -357,14 +362,14 @@ func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenRoundI func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenInitCurrentRoundReturnTrue(t *testing.T) { t.Parallel() - bootstrapperMock := &mock.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { return common.NsSynchronized }} - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetBootStrapper(bootstrapperMock) - sr := 
*initSubroundStartRoundWithContainer(container) + sr := initSubroundStartRoundWithContainer(container) sentTrackerInterface := sr.GetSentSignatureTracker() sentTracker := sentTrackerInterface.(*testscommon.SentSignatureTrackerStub) startRoundCalled := false @@ -380,15 +385,15 @@ func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenInitCu func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenInitCurrentRoundReturnFalse(t *testing.T) { t.Parallel() - bootstrapperMock := &mock.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { return common.NsNotSynchronized }} - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetBootStrapper(bootstrapperMock) container.SetRoundHandler(initRoundHandlerMock()) - sr := *initSubroundStartRoundWithContainer(container) + sr := initSubroundStartRoundWithContainer(container) ok := sr.DoStartRoundConsensusCheck() assert.False(t, ok) @@ -397,15 +402,15 @@ func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenInitC func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGetNodeStateNotReturnSynchronized(t *testing.T) { t.Parallel() - bootstrapperMock := &mock.BootstrapperStub{} + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} bootstrapperMock.GetNodeStateCalled = func() common.NodeState { return common.NsNotSynchronized } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetBootStrapper(bootstrapperMock) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.False(t, r) @@ -416,13 +421,13 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGenerateNextCon validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} err := errors.New("error") - validatorGroupSelector.ComputeValidatorsGroupCalled = func(bytes []byte, round uint64, shardId uint32, epoch uint32) ([]nodesCoordinator.Validator, error) { - return nil, err + validatorGroupSelector.ComputeValidatorsGroupCalled = func(bytes []byte, round uint64, shardId uint32, epoch uint32) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + return nil, nil, err } - container := mock.InitConsensusCore() - container.SetValidatorGroupSelector(validatorGroupSelector) + container := consensusMocks.InitConsensusCore() + container.SetNodesCoordinator(validatorGroupSelector) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.False(t, r) @@ -436,10 +441,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnTrueWhenMainMachineIsAct return true }, } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetNodeRedundancyHandler(nodeRedundancyMock) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.True(t, r) @@ -449,19 +454,24 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGetLeaderErr(t t.Parallel() validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} + leader := &shardingMocks.ValidatorMock{PubKeyCalled: func() []byte { + return []byte("leader") + }} + 
validatorGroupSelector.ComputeValidatorsGroupCalled = func( bytes []byte, round uint64, shardId uint32, epoch uint32, - ) ([]nodesCoordinator.Validator, error) { - return make([]nodesCoordinator.Validator, 0), nil + ) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + // will cause an error in GetLeader because of empty consensus group + return leader, []nodesCoordinator.Validator{}, nil } - container := mock.InitConsensusCore() - container.SetValidatorGroupSelector(validatorGroupSelector) + container := consensusMocks.InitConsensusCore() + container.SetNodesCoordinator(validatorGroupSelector) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.False(t, r) @@ -470,14 +480,14 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGetLeaderErr(t func TestSubroundStartRound_InitCurrentRoundShouldReturnTrueWhenIsNotInTheConsensusGroup(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() - consensusState := initConsensusState() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() consensusState.SetSelfPubKey(consensusState.SelfPubKey() + "X") ch := make(chan bool, 1) sr, _ := defaultSubround(consensusState, ch, container) - srStartRound := *defaultWithoutErrorSubroundStartRoundFromSubround(sr) + srStartRound := defaultWithoutErrorSubroundStartRoundFromSubround(sr) r := srStartRound.InitCurrentRound() assert.True(t, r) @@ -492,10 +502,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenTimeIsOut(t *te return time.Duration(-1) } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetRoundHandler(roundHandlerMock) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.False(t, r) @@ -504,16 +514,16 @@ func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenTimeIsOut(t *te func TestSubroundStartRound_InitCurrentRoundShouldReturnTrue(t *testing.T) { t.Parallel() - bootstrapperMock := &mock.BootstrapperStub{} + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} bootstrapperMock.GetNodeStateCalled = func() common.NodeState { return common.NsSynchronized } - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() container.SetBootStrapper(bootstrapperMock) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) r := srStartRound.InitCurrentRound() assert.True(t, r) @@ -526,18 +536,18 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { t.Parallel() wasCalled := false - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{} appStatusHandler := &statusHandler.AppStatusHandlerStub{ SetStringValueHandler: func(key string, value string) { if key == common.MetricConsensusState { wasCalled = true - assert.Equal(t, value, "not in consensus group") + assert.Equal(t, "not in consensus group", value) } }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) consensusState.SetSelfPubKey("not in consensus") sr, _ := spos.NewSubround( -1, @@ -555,10 +565,10 @@ func 
TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { appStatusHandler, ) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, executeStoredMessages, &testscommon.SentSignatureTrackerStub{}, @@ -571,7 +581,7 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { wasCalled := false wasIncrementCalled := false - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{ IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { return string(pkBytes) == "B" @@ -591,7 +601,7 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) consensusState.SetSelfPubKey("B") sr, _ := spos.NewSubround( -1, @@ -609,10 +619,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { appStatusHandler, ) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, executeStoredMessages, &testscommon.SentSignatureTrackerStub{}, @@ -626,13 +636,13 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { wasCalled := false wasIncrementCalled := false - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{} appStatusHandler := &statusHandler.AppStatusHandlerStub{ SetStringValueHandler: func(key string, value string) { if key == common.MetricConsensusState { wasCalled = true - assert.Equal(t, value, "participant") + assert.Equal(t, "participant", value) } }, IncrementHandler: func(key string) { @@ -642,7 +652,8 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + consensusState.SetSelfPubKey("B") keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { return string(pkBytes) == consensusState.SelfPubKey() } @@ -662,10 +673,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { appStatusHandler, ) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, executeStoredMessages, &testscommon.SentSignatureTrackerStub{}, @@ -680,21 +691,21 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { wasMetricConsensusStateCalled := false wasMetricCountLeaderCalled := false cntMetricConsensusRoundStateCalled := 0 - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{} appStatusHandler := &statusHandler.AppStatusHandlerStub{ SetStringValueHandler: func(key string, value string) { if key == common.MetricConsensusState { wasMetricConsensusStateCalled = true - assert.Equal(t, value, "proposer") + assert.Equal(t, "proposer", value) } if key == common.MetricConsensusRoundState { cntMetricConsensusRoundStateCalled++ switch cntMetricConsensusRoundStateCalled { case 1: - 
assert.Equal(t, value, "") + assert.Equal(t, "", value) case 2: - assert.Equal(t, value, "proposed") + assert.Equal(t, "proposed", value) default: assert.Fail(t, "should have been called only twice") } @@ -707,9 +718,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) leader, _ := consensusState.GetLeader() consensusState.SetSelfPubKey(leader) + sr, _ := spos.NewSubround( -1, bls.SrStartRound, @@ -726,10 +738,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { appStatusHandler, ) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, executeStoredMessages, &testscommon.SentSignatureTrackerStub{}, @@ -745,21 +757,21 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { wasMetricConsensusStateCalled := false wasMetricCountLeaderCalled := false cntMetricConsensusRoundStateCalled := 0 - container := mock.InitConsensusCore() + container := consensusMocks.InitConsensusCore() keysHandler := &testscommon.KeysHandlerStub{} appStatusHandler := &statusHandler.AppStatusHandlerStub{ SetStringValueHandler: func(key string, value string) { if key == common.MetricConsensusState { wasMetricConsensusStateCalled = true - assert.Equal(t, value, "proposer") + assert.Equal(t, "proposer", value) } if key == common.MetricConsensusRoundState { cntMetricConsensusRoundStateCalled++ switch cntMetricConsensusRoundStateCalled { case 1: - assert.Equal(t, value, "") + assert.Equal(t, "", value) case 2: - assert.Equal(t, value, "proposed") + assert.Equal(t, "proposed", value) default: assert.Fail(t, "should have been called only twice") } @@ -772,7 +784,7 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { }, } ch := make(chan bool, 1) - consensusState := initConsensusStateWithKeysHandler(keysHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) leader, _ := consensusState.GetLeader() consensusState.SetSelfPubKey(leader) keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { @@ -794,10 +806,10 @@ func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { appStatusHandler, ) - srStartRound, _ := bls.NewSubroundStartRound( + srStartRound, _ := v1.NewSubroundStartRound( sr, extend, - bls.ProcessingThresholdPercent, + v1.ProcessingThresholdPercent, displayStatistics, executeStoredMessages, &testscommon.SentSignatureTrackerStub{}, @@ -820,13 +832,13 @@ func TestSubroundStartRound_GenerateNextConsensusGroupShouldReturnErr(t *testing round uint64, shardId uint32, epoch uint32, - ) ([]nodesCoordinator.Validator, error) { - return nil, err + ) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + return nil, nil, err } - container := mock.InitConsensusCore() - container.SetValidatorGroupSelector(validatorGroupSelector) + container := consensusMocks.InitConsensusCore() + container.SetNodesCoordinator(validatorGroupSelector) - srStartRound := *initSubroundStartRoundWithContainer(container) + srStartRound := initSubroundStartRoundWithContainer(container) err2 := srStartRound.GenerateNextConsensusGroup(0) diff --git a/consensus/spos/bls/v2/benchmark_test.go b/consensus/spos/bls/v2/benchmark_test.go new file mode 100644 index 
00000000000..5fb3c3a253f --- /dev/null +++ b/consensus/spos/bls/v2/benchmark_test.go @@ -0,0 +1,138 @@ +package v2_test + +import ( + "context" + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data/block" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/signing" + "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + mclMultiSig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/multisig" + "github.com/multiversx/mx-chain-crypto-go/signing/multisig" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + cryptoFactory "github.com/multiversx/mx-chain-go/factory/crypto" + "github.com/multiversx/mx-chain-go/testscommon" + nodeMock "github.com/multiversx/mx-chain-go/testscommon/common" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +func BenchmarkSubroundSignature_doSignatureJobForManagedKeys63(b *testing.B) { + benchmarkSubroundSignatureDoSignatureJobForManagedKeys(b, 63) +} + +func BenchmarkSubroundSignature_doSignatureJobForManagedKeys400(b *testing.B) { + benchmarkSubroundSignatureDoSignatureJobForManagedKeys(b, 400) +} + +func createMultiSignerSetup(grSize uint16, suite crypto.Suite) (crypto.KeyGenerator, map[string]crypto.PrivateKey) { + kg := signing.NewKeyGenerator(suite) + mapKeys := make(map[string]crypto.PrivateKey) + + for i := uint16(0); i < grSize; i++ { + sk, pk := kg.GeneratePair() + + pubKey, _ := pk.ToByteArray() + mapKeys[string(pubKey)] = sk + } + return kg, mapKeys +} + +func benchmarkSubroundSignatureDoSignatureJobForManagedKeys(b *testing.B, numberOfKeys int) { + container := consensus.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + llSigner := &mclMultiSig.BlsMultiSignerKOSK{} + + suite := mcl.NewSuiteBLS12() + kg, mapKeys := createMultiSignerSetup(uint16(numberOfKeys), suite) + + multiSigHandler, _ := multisig.NewBLSMultisig(llSigner, kg) + + keysHandlerMock := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + GetHandledPrivateKeyCalled: func(pkBytes []byte) crypto.PrivateKey { + return mapKeys[string(pkBytes)] + }, + } + + args := cryptoFactory.ArgsSigningHandler{ + PubKeys: initializers.CreateEligibleListFromMap(mapKeys), + MultiSignerContainer: &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return multiSigHandler, nil + }}, + SingleSigner: &cryptoMocks.SingleSignerStub{}, 
+ KeyGenerator: kg, + KeysHandler: keysHandlerMock, + } + signingHandler, err := cryptoFactory.NewSigningHandler(args) + require.Nil(b, err) + + container.SetSigningHandler(signingHandler) + consensusState := initializers.InitConsensusStateWithArgs(keysHandlerMock, mapKeys) + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + signatureSentForPks := make(map[string]struct{}) + mutex := sync.Mutex{} + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + mutex.Lock() + signatureSentForPks[string(pkBytes)] = struct{}{} + mutex.Unlock() + }, + }, + &consensus.SposWorkerMock{}, + &nodeMock.ThrottlerStub{}, + ) + + sr.SetHeader(&block.Header{}) + sr.SetSelfPubKey("OTHER") + + b.ResetTimer() + b.StopTimer() + + for i := 0; i < b.N; i++ { + b.StartTimer() + r := srSignature.DoSignatureJobForManagedKeys(context.TODO()) + b.StopTimer() + + require.True(b, r) + } +} diff --git a/consensus/spos/bls/v2/benchmark_verify_signatures_test.go b/consensus/spos/bls/v2/benchmark_verify_signatures_test.go new file mode 100644 index 00000000000..46d18d8460e --- /dev/null +++ b/consensus/spos/bls/v2/benchmark_verify_signatures_test.go @@ -0,0 +1,123 @@ +package v2_test + +import ( + "context" + "sort" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-crypto-go/signing" + "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + "github.com/stretchr/testify/require" + + crypto "github.com/multiversx/mx-chain-crypto-go" + mclMultisig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/multisig" + "github.com/multiversx/mx-chain-crypto-go/signing/multisig" + + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + + "github.com/multiversx/mx-chain-go/common" + factoryCrypto "github.com/multiversx/mx-chain-go/factory/crypto" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +const benchmarkKeyPairsCardinal = 400 + +// createListFromMapKeys make a predictable iteration on keys from a map of keys +func createListFromMapKeys(mapKeys map[string]crypto.PrivateKey) []string { + keys := make([]string, 0, len(mapKeys)) + + for key := range mapKeys { + keys = append(keys, key) + } + + sort.Strings(keys) + + return keys +} + +// generateKeyPairs generates benchmarkKeyPairsCardinal number of pairs(public key & private key) +func generateKeyPairs(kg crypto.KeyGenerator) map[string]crypto.PrivateKey { + mapKeys := make(map[string]crypto.PrivateKey) + + for i := uint16(0); i < 
benchmarkKeyPairsCardinal; i++ { + sk, pk := kg.GeneratePair() + + pubKey, _ := pk.ToByteArray() + mapKeys[string(pubKey)] = sk + } + return mapKeys +} + +// BenchmarkSubroundEndRound_VerifyNodesOnAggSigFailTime measure time needed to verify signatures +func BenchmarkSubroundEndRound_VerifyNodesOnAggSigFailTime(b *testing.B) { + + b.ResetTimer() + b.StopTimer() + ctx, cancel := context.WithCancel(context.TODO()) + + defer func() { + cancel() + }() + + container := consensus.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + llSigner := &mclMultisig.BlsMultiSignerKOSK{} + suite := mcl.NewSuiteBLS12() + kg := signing.NewKeyGenerator(suite) + + multiSigHandler, _ := multisig.NewBLSMultisig(llSigner, kg) + + mapKeys := generateKeyPairs(kg) + + keysHandlerMock := &testscommon.KeysHandlerStub{ + GetHandledPrivateKeyCalled: func(pkBytes []byte) crypto.PrivateKey { + return mapKeys[string(pkBytes)] + }, + } + keys := createListFromMapKeys(mapKeys) + args := factoryCrypto.ArgsSigningHandler{ + PubKeys: keys, + MultiSignerContainer: &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return multiSigHandler, nil + }, + }, + SingleSigner: &cryptoMocks.SingleSignerStub{}, + KeyGenerator: kg, + KeysHandler: keysHandlerMock, + } + + signingHandler, err := factoryCrypto.NewSigningHandler(args) + require.Nil(b, err) + + container.SetSigningHandler(signingHandler) + consensusState := initializers.InitConsensusStateWithArgsVerifySignature(keysHandlerMock, keys) + dataToBeSigned := []byte("message") + consensusState.Data = dataToBeSigned + + sr := initSubroundEndRoundWithContainerAndConsensusState(container, &statusHandler.AppStatusHandlerStub{}, consensusState, &dataRetrieverMocks.ThrottlerStub{}) + for i := 0; i < len(sr.ConsensusGroup()); i++ { + _, err := sr.SigningHandler().CreateSignatureShareForPublicKey(dataToBeSigned, uint16(i), sr.EnableEpochsHandler().GetCurrentEpoch(), []byte(keys[i])) + require.Nil(b, err) + _ = sr.SetJobDone(keys[i], bls.SrSignature, true) + } + for i := 0; i < b.N; i++ { + b.StartTimer() + invalidSigners, err := sr.VerifyNodesOnAggSigFail(ctx) + b.StopTimer() + require.Nil(b, err) + require.NotNil(b, invalidSigners) + } +} diff --git a/consensus/spos/bls/v2/blsSubroundsFactory.go b/consensus/spos/bls/v2/blsSubroundsFactory.go new file mode 100644 index 00000000000..a2b35dcefe3 --- /dev/null +++ b/consensus/spos/bls/v2/blsSubroundsFactory.go @@ -0,0 +1,315 @@ +package v2 + +import ( + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/outport" +) + +// factory defines the data needed by this factory to create all the subrounds and give them their specific +// functionality +type factory struct { + consensusCore spos.ConsensusCoreHandler + consensusState spos.ConsensusStateHandler + worker spos.WorkerHandler + + appStatusHandler core.AppStatusHandler + outportHandler outport.OutportHandler + sentSignaturesTracker spos.SentSignaturesTracker + chainID []byte + currentPid core.PeerID + signatureThrottler 
core.Throttler +} + +// NewSubroundsFactory creates a new subrounds factory object +func NewSubroundsFactory( + consensusDataContainer spos.ConsensusCoreHandler, + consensusState spos.ConsensusStateHandler, + worker spos.WorkerHandler, + chainID []byte, + currentPid core.PeerID, + appStatusHandler core.AppStatusHandler, + sentSignaturesTracker spos.SentSignaturesTracker, + signatureThrottler core.Throttler, + outportHandler outport.OutportHandler, +) (*factory, error) { + // no need to check the outport handler, it can be nil + err := checkNewFactoryParams( + consensusDataContainer, + consensusState, + worker, + chainID, + appStatusHandler, + sentSignaturesTracker, + signatureThrottler, + ) + if err != nil { + return nil, err + } + + fct := factory{ + consensusCore: consensusDataContainer, + consensusState: consensusState, + worker: worker, + appStatusHandler: appStatusHandler, + chainID: chainID, + currentPid: currentPid, + sentSignaturesTracker: sentSignaturesTracker, + signatureThrottler: signatureThrottler, + outportHandler: outportHandler, + } + + return &fct, nil +} + +func checkNewFactoryParams( + container spos.ConsensusCoreHandler, + state spos.ConsensusStateHandler, + worker spos.WorkerHandler, + chainID []byte, + appStatusHandler core.AppStatusHandler, + sentSignaturesTracker spos.SentSignaturesTracker, + signatureThrottler core.Throttler, +) error { + err := spos.ValidateConsensusCore(container) + if err != nil { + return err + } + if state == nil { + return spos.ErrNilConsensusState + } + if check.IfNil(worker) { + return spos.ErrNilWorker + } + if check.IfNil(appStatusHandler) { + return spos.ErrNilAppStatusHandler + } + if check.IfNil(sentSignaturesTracker) { + return ErrNilSentSignatureTracker + } + if check.IfNil(signatureThrottler) { + return spos.ErrNilThrottler + } + if len(chainID) == 0 { + return spos.ErrInvalidChainID + } + + return nil +} + +// SetOutportHandler method will update the value of the factory's outport +func (fct *factory) SetOutportHandler(driver outport.OutportHandler) { + fct.outportHandler = driver +} + +// GenerateSubrounds will generate the subrounds used in BLS consensus +func (fct *factory) GenerateSubrounds(epoch uint32) error { + fct.initConsensusThreshold(epoch) + fct.consensusCore.Chronology().RemoveAllSubrounds() + fct.worker.RemoveAllReceivedMessagesCalls() + fct.worker.RemoveAllReceivedHeaderHandlers() + + err := fct.generateStartRoundSubround() + if err != nil { + return err + } + + err = fct.generateBlockSubround() + if err != nil { + return err + } + + err = fct.generateSignatureSubround() + if err != nil { + return err + } + + err = fct.generateEndRoundSubround() + if err != nil { + return err + } + + return nil +} + +func (fct *factory) getTimeDuration() time.Duration { + return fct.consensusCore.RoundHandler().TimeDuration() +} + +func (fct *factory) generateStartRoundSubround() error { + subround, err := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(float64(fct.getTimeDuration())*srStartStartTime), + int64(float64(fct.getTimeDuration())*srStartEndTime), + bls.GetSubroundName(bls.SrStartRound), + fct.consensusState, + fct.worker.GetConsensusStateChangedChannel(), + fct.worker.ExecuteStoredMessages, + fct.consensusCore, + fct.chainID, + fct.currentPid, + fct.appStatusHandler, + ) + if err != nil { + return err + } + + subroundStartRoundInstance, err := NewSubroundStartRound( + subround, + processingThresholdPercent, + fct.sentSignaturesTracker, + fct.worker, + ) + if err != nil { + return err + } + + err = 
subroundStartRoundInstance.SetOutportHandler(fct.outportHandler) + if err != nil { + return err + } + + fct.consensusCore.Chronology().AddSubround(subroundStartRoundInstance) + + return nil +} + +func (fct *factory) generateBlockSubround() error { + subround, err := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(float64(fct.getTimeDuration())*srBlockStartTime), + int64(float64(fct.getTimeDuration())*srBlockEndTime), + bls.GetSubroundName(bls.SrBlock), + fct.consensusState, + fct.worker.GetConsensusStateChangedChannel(), + fct.worker.ExecuteStoredMessages, + fct.consensusCore, + fct.chainID, + fct.currentPid, + fct.appStatusHandler, + ) + if err != nil { + return err + } + + subroundBlockInstance, err := NewSubroundBlock( + subround, + processingThresholdPercent, + fct.worker, + ) + if err != nil { + return err + } + + fct.worker.AddReceivedMessageCall(bls.MtBlockBody, subroundBlockInstance.receivedBlockBody) + fct.worker.AddReceivedHeaderHandler(subroundBlockInstance.receivedBlockHeader) + fct.consensusCore.Chronology().AddSubround(subroundBlockInstance) + + return nil +} + +func (fct *factory) generateSignatureSubround() error { + subround, err := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(float64(fct.getTimeDuration())*srSignatureStartTime), + int64(float64(fct.getTimeDuration())*srSignatureEndTime), + bls.GetSubroundName(bls.SrSignature), + fct.consensusState, + fct.worker.GetConsensusStateChangedChannel(), + fct.worker.ExecuteStoredMessages, + fct.consensusCore, + fct.chainID, + fct.currentPid, + fct.appStatusHandler, + ) + if err != nil { + return err + } + + subroundSignatureObject, err := NewSubroundSignature( + subround, + fct.appStatusHandler, + fct.sentSignaturesTracker, + fct.worker, + fct.signatureThrottler, + ) + if err != nil { + return err + } + + fct.consensusCore.Chronology().AddSubround(subroundSignatureObject) + + return nil +} + +func (fct *factory) generateEndRoundSubround() error { + subround, err := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(float64(fct.getTimeDuration())*srEndStartTime), + int64(float64(fct.getTimeDuration())*srEndEndTime), + bls.GetSubroundName(bls.SrEndRound), + fct.consensusState, + fct.worker.GetConsensusStateChangedChannel(), + fct.worker.ExecuteStoredMessages, + fct.consensusCore, + fct.chainID, + fct.currentPid, + fct.appStatusHandler, + ) + if err != nil { + return err + } + + subroundEndRoundObject, err := NewSubroundEndRound( + subround, + spos.MaxThresholdPercent, + fct.appStatusHandler, + fct.sentSignaturesTracker, + fct.worker, + fct.signatureThrottler, + ) + if err != nil { + return err + } + + fct.worker.AddReceivedProofHandler(subroundEndRoundObject.receivedProof) + fct.worker.AddReceivedMessageCall(bls.MtInvalidSigners, subroundEndRoundObject.receivedInvalidSignersInfo) + fct.worker.AddReceivedMessageCall(bls.MtSignature, subroundEndRoundObject.receivedSignature) + fct.consensusCore.Chronology().AddSubround(subroundEndRoundObject) + + return nil +} + +func (fct *factory) initConsensusThreshold(epoch uint32) { + consensusGroupSizeForEpoch := fct.consensusCore.NodesCoordinator().ConsensusGroupSizeForShardAndEpoch(fct.consensusCore.ShardCoordinator().SelfId(), epoch) + pBFTThreshold := core.GetPBFTThreshold(consensusGroupSizeForEpoch) + pBFTFallbackThreshold := core.GetPBFTFallbackThreshold(consensusGroupSizeForEpoch) + fct.consensusState.SetThreshold(bls.SrBlock, 1) + fct.consensusState.SetThreshold(bls.SrSignature, pBFTThreshold) + 
fct.consensusState.SetFallbackThreshold(bls.SrBlock, 1) + fct.consensusState.SetFallbackThreshold(bls.SrSignature, pBFTFallbackThreshold) + + log.Debug("initConsensusThreshold updating thresholds", + "epoch", epoch, + "pBFTThreshold", pBFTThreshold, + "pBFTFallbackThreshold", pBFTFallbackThreshold, + ) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (fct *factory) IsInterfaceNil() bool { + return fct == nil +} diff --git a/consensus/spos/bls/v2/blsSubroundsFactory_test.go b/consensus/spos/bls/v2/blsSubroundsFactory_test.go new file mode 100644 index 00000000000..5be8b6bdbcd --- /dev/null +++ b/consensus/spos/bls/v2/blsSubroundsFactory_test.go @@ -0,0 +1,692 @@ +package v2_test + +import ( + "context" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/outport" + "github.com/multiversx/mx-chain-go/testscommon" + testscommonConsensus "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + testscommonOutport "github.com/multiversx/mx-chain-go/testscommon/outport" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +var chainID = []byte("chain ID") + +const currentPid = core.PeerID("pid") + +const roundTimeDuration = 100 * time.Millisecond + +// executeStoredMessages tries to execute all the messages received which are valid for execution +func executeStoredMessages() { +} + +func initRoundHandlerMock() *testscommonConsensus.RoundHandlerMock { + return &testscommonConsensus.RoundHandlerMock{ + RoundIndex: 0, + TimeStampCalled: func() time.Time { + return time.Unix(0, 0) + }, + TimeDurationCalled: func() time.Duration { + return roundTimeDuration + }, + } +} + +func initWorker() spos.WorkerHandler { + sposWorker := &testscommonConsensus.SposWorkerMock{} + sposWorker.GetConsensusStateChangedChannelsCalled = func() chan bool { + return make(chan bool) + } + sposWorker.RemoveAllReceivedMessagesCallsCalled = func() {} + + sposWorker.AddReceivedMessageCallCalled = + func(messageType consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) { + } + + return sposWorker +} + +func initFactoryWithContainer(container *spos.ConsensusCore) v2.Factory { + worker := initWorker() + consensusState := initializers.InitConsensusState() + + fct, _ := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + return fct +} + +func initFactory() v2.Factory { + container := testscommonConsensus.InitConsensusCore() + return initFactoryWithContainer(container) +} + +func 
TestFactory_GetMessageTypeName(t *testing.T) { + t.Parallel() + + r := bls.GetStringValue(bls.MtBlockBodyAndHeader) + assert.Equal(t, "(BLOCK_BODY_AND_HEADER)", r) + + r = bls.GetStringValue(bls.MtBlockBody) + assert.Equal(t, "(BLOCK_BODY)", r) + + r = bls.GetStringValue(bls.MtBlockHeader) + assert.Equal(t, "(BLOCK_HEADER)", r) + + r = bls.GetStringValue(bls.MtSignature) + assert.Equal(t, "(SIGNATURE)", r) + + r = bls.GetStringValue(bls.MtBlockHeaderFinalInfo) + assert.Equal(t, "(FINAL_INFO)", r) + + r = bls.GetStringValue(bls.MtUnknown) + assert.Equal(t, "(UNKNOWN)", r) + + r = bls.GetStringValue(consensus.MessageType(-1)) + assert.Equal(t, "Undefined message type", r) +} + +func TestFactory_NewFactoryNilContainerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + nil, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilConsensusCore, err) +} + +func TestFactory_NewFactoryNilConsensusStateShouldFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + container, + nil, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilConsensusState, err) +} + +func TestFactory_NewFactoryNilBlockchainShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetBlockchain(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilBlockChain, err) +} + +func TestFactory_NewFactoryNilBlockProcessorShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetBlockProcessor(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilBlockProcessor, err) +} + +func TestFactory_NewFactoryNilBootstrapperShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetBootStrapper(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilBootstrapper, err) +} + +func TestFactory_NewFactoryNilChronologyHandlerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + 
worker := initWorker() + container.SetChronology(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilChronologyHandler, err) +} + +func TestFactory_NewFactoryNilHasherShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetHasher(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilHasher, err) +} + +func TestFactory_NewFactoryNilMarshalizerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetMarshalizer(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilMarshalizer, err) +} + +func TestFactory_NewFactoryNilMultiSignerContainerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetMultiSignerContainer(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) +} + +func TestFactory_NewFactoryNilRoundHandlerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetRoundHandler(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilRoundHandler, err) +} + +func TestFactory_NewFactoryNilShardCoordinatorShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetShardCoordinator(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilShardCoordinator, err) +} + +func TestFactory_NewFactoryNilSyncTimerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetSyncTimer(nil) + + fct, err := v2.NewSubroundsFactory( + container, + 
consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestFactory_NewFactoryNilValidatorGroupSelectorShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + container.SetNodesCoordinator(nil) + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilNodesCoordinator, err) +} + +func TestFactory_NewFactoryNilWorkerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + nil, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilWorker, err) +} + +func TestFactory_NewFactoryNilAppStatusHandlerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + nil, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilAppStatusHandler, err) +} + +func TestFactory_NewFactoryNilSignaturesTrackerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + nil, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, v2.ErrNilSentSignatureTracker, err) +} + +func TestFactory_NewFactoryNilThrottlerShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + nil, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrNilThrottler, err) +} + +func TestFactory_NewFactoryShouldWork(t *testing.T) { + t.Parallel() + + fct := *initFactory() + + assert.False(t, check.IfNil(&fct)) +} + +func TestFactory_NewFactoryEmptyChainIDShouldFail(t *testing.T) { + t.Parallel() + + consensusState := initializers.InitConsensusState() + container := testscommonConsensus.InitConsensusCore() + worker := initWorker() + + fct, err := v2.NewSubroundsFactory( + container, + consensusState, + worker, + nil, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &dataRetrieverMocks.ThrottlerStub{}, + nil, + ) + + assert.Nil(t, fct) + assert.Equal(t, spos.ErrInvalidChainID, err) +} + +func 
TestFactory_GenerateSubroundStartRoundShouldFailWhenNewSubroundFail(t *testing.T) { + t.Parallel() + + fct := *initFactory() + fct.Worker().(*testscommonConsensus.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + return nil + } + + err := fct.GenerateStartRoundSubround() + + assert.Equal(t, spos.ErrNilChannel, err) +} + +func TestFactory_GenerateSubroundStartRoundShouldFailWhenNewSubroundStartRoundFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + container.SetSyncTimer(nil) + + err := fct.GenerateStartRoundSubround() + + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestFactory_GenerateSubroundBlockShouldFailWhenNewSubroundFail(t *testing.T) { + t.Parallel() + + fct := *initFactory() + fct.Worker().(*testscommonConsensus.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + return nil + } + + err := fct.GenerateBlockSubround() + + assert.Equal(t, spos.ErrNilChannel, err) +} + +func TestFactory_GenerateSubroundBlockShouldFailWhenNewSubroundBlockFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + container.SetSyncTimer(nil) + + err := fct.GenerateBlockSubround() + + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestFactory_GenerateSubroundSignatureShouldFailWhenNewSubroundFail(t *testing.T) { + t.Parallel() + + fct := *initFactory() + fct.Worker().(*testscommonConsensus.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + return nil + } + + err := fct.GenerateSignatureSubround() + + assert.Equal(t, spos.ErrNilChannel, err) +} + +func TestFactory_GenerateSubroundSignatureShouldFailWhenNewSubroundSignatureFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + container.SetSyncTimer(nil) + + err := fct.GenerateSignatureSubround() + + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestFactory_GenerateSubroundEndRoundShouldFailWhenNewSubroundFail(t *testing.T) { + t.Parallel() + + fct := *initFactory() + fct.Worker().(*testscommonConsensus.SposWorkerMock).GetConsensusStateChangedChannelsCalled = func() chan bool { + return nil + } + + err := fct.GenerateEndRoundSubround() + + assert.Equal(t, spos.ErrNilChannel, err) +} + +func TestFactory_GenerateSubroundEndRoundShouldFailWhenNewSubroundEndRoundFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + container.SetSyncTimer(nil) + + err := fct.GenerateEndRoundSubround() + + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestFactory_GenerateSubroundsShouldWork(t *testing.T) { + t.Parallel() + + subroundHandlers := 0 + + chrm := &testscommonConsensus.ChronologyHandlerMock{} + chrm.AddSubroundCalled = func(subroundHandler consensus.SubroundHandler) { + subroundHandlers++ + } + container := testscommonConsensus.InitConsensusCore() + container.SetChronology(chrm) + providedEpoch := uint32(123) + wasConsensusGroupSizeCalled := false + container.SetNodesCoordinator(&shardingMocks.NodesCoordinatorMock{ + ConsensusGroupSizeCalled: func(shard uint32, epoch uint32) int { + wasConsensusGroupSizeCalled = true + require.Equal(t, providedEpoch, epoch) + return 1 + }, + }) + fct := *initFactoryWithContainer(container) + fct.SetOutportHandler(&testscommonOutport.OutportStub{}) + + err := 
fct.GenerateSubrounds(providedEpoch) + assert.Nil(t, err) + require.True(t, wasConsensusGroupSizeCalled) + + assert.Equal(t, 4, subroundHandlers) +} + +func TestFactory_GenerateSubroundsNilOutportShouldFail(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + + err := fct.GenerateSubrounds(0) + assert.Equal(t, outport.ErrNilDriver, err) +} + +func TestFactory_SetIndexerShouldWork(t *testing.T) { + t.Parallel() + + container := testscommonConsensus.InitConsensusCore() + fct := *initFactoryWithContainer(container) + + outportHandler := &testscommonOutport.OutportStub{} + fct.SetOutportHandler(outportHandler) + + assert.Equal(t, outportHandler, fct.Outport()) +} diff --git a/consensus/spos/bls/v2/constants.go b/consensus/spos/bls/v2/constants.go new file mode 100644 index 00000000000..93856652b39 --- /dev/null +++ b/consensus/spos/bls/v2/constants.go @@ -0,0 +1,37 @@ +package v2 + +import ( + logger "github.com/multiversx/mx-chain-logger-go" +) + +var log = logger.GetOrCreate("consensus/spos/bls/v2") + +// waitingAllSigsMaxTimeThreshold specifies the max allocated time for waiting all signatures from the total time of the subround signature +const waitingAllSigsMaxTimeThreshold = 0.5 + +// processingThresholdPercent specifies the max allocated time for processing the block as a percentage of the total time of the round +const processingThresholdPercent = 85 + +// srStartStartTime specifies the start time, from the total time of the round, of Subround Start +const srStartStartTime = 0.0 + +// srEndStartTime specifies the end time, from the total time of the round, of Subround Start +const srStartEndTime = 0.05 + +// srBlockStartTime specifies the start time, from the total time of the round, of Subround Block +const srBlockStartTime = 0.05 + +// srBlockEndTime specifies the end time, from the total time of the round, of Subround Block +const srBlockEndTime = 0.25 + +// srSignatureStartTime specifies the start time, from the total time of the round, of Subround Signature +const srSignatureStartTime = 0.25 + +// srSignatureEndTime specifies the end time, from the total time of the round, of Subround Signature +const srSignatureEndTime = 0.85 + +// srEndStartTime specifies the start time, from the total time of the round, of Subround End +const srEndStartTime = 0.85 + +// srEndEndTime specifies the end time, from the total time of the round, of Subround End +const srEndEndTime = 0.95 diff --git a/consensus/spos/bls/v2/errors.go b/consensus/spos/bls/v2/errors.go new file mode 100644 index 00000000000..545e84ac760 --- /dev/null +++ b/consensus/spos/bls/v2/errors.go @@ -0,0 +1,12 @@ +package v2 + +import "errors" + +// ErrNilSentSignatureTracker defines the error for setting a nil SentSignatureTracker +var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") + +// ErrTimeOut signals that the time is out +var ErrTimeOut = errors.New("time is out") + +// ErrProofAlreadyPropagated signals that the proof was already propagated +var ErrProofAlreadyPropagated = errors.New("proof already propagated") diff --git a/consensus/spos/bls/v2/export_test.go b/consensus/spos/bls/v2/export_test.go new file mode 100644 index 00000000000..61193eb6193 --- /dev/null +++ b/consensus/spos/bls/v2/export_test.go @@ -0,0 +1,362 @@ +package v2 + +import ( + "context" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" + 
"github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + + cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/ntp" + "github.com/multiversx/mx-chain-go/outport" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/sharding" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" +) + +const ProcessingThresholdPercent = processingThresholdPercent + +// factory + +// Factory defines a type for the factory structure +type Factory *factory + +// BlockChain gets the chain handler object +func (fct *factory) BlockChain() data.ChainHandler { + return fct.consensusCore.Blockchain() +} + +// BlockProcessor gets the block processor object +func (fct *factory) BlockProcessor() process.BlockProcessor { + return fct.consensusCore.BlockProcessor() +} + +// Bootstrapper gets the bootstrapper object +func (fct *factory) Bootstrapper() process.Bootstrapper { + return fct.consensusCore.BootStrapper() +} + +// ChronologyHandler gets the chronology handler object +func (fct *factory) ChronologyHandler() consensus.ChronologyHandler { + return fct.consensusCore.Chronology() +} + +// ConsensusState gets the consensus state struct pointer +func (fct *factory) ConsensusState() spos.ConsensusStateHandler { + return fct.consensusState +} + +// Hasher gets the hasher object +func (fct *factory) Hasher() hashing.Hasher { + return fct.consensusCore.Hasher() +} + +// Marshalizer gets the marshalizer object +func (fct *factory) Marshalizer() marshal.Marshalizer { + return fct.consensusCore.Marshalizer() +} + +// MultiSigner gets the multi signer object +func (fct *factory) MultiSignerContainer() cryptoCommon.MultiSignerContainer { + return fct.consensusCore.MultiSignerContainer() +} + +// RoundHandler gets the roundHandler object +func (fct *factory) RoundHandler() consensus.RoundHandler { + return fct.consensusCore.RoundHandler() +} + +// ShardCoordinator gets the shard coordinator object +func (fct *factory) ShardCoordinator() sharding.Coordinator { + return fct.consensusCore.ShardCoordinator() +} + +// SyncTimer gets the sync timer object +func (fct *factory) SyncTimer() ntp.SyncTimer { + return fct.consensusCore.SyncTimer() +} + +// NodesCoordinator gets the nodes coordinator object +func (fct *factory) NodesCoordinator() nodesCoordinator.NodesCoordinator { + return fct.consensusCore.NodesCoordinator() +} + +// Worker gets the worker object +func (fct *factory) Worker() spos.WorkerHandler { + return fct.worker +} + +// SetWorker sets the worker object +func (fct *factory) SetWorker(worker spos.WorkerHandler) { + fct.worker = worker +} + +// GenerateStartRoundSubround generates the instance of subround StartRound and added it to the chronology subrounds list +func (fct *factory) GenerateStartRoundSubround() error { + return fct.generateStartRoundSubround() +} + +// GenerateBlockSubround generates the instance of subround Block and added it to the chronology subrounds list +func (fct *factory) GenerateBlockSubround() error { + return fct.generateBlockSubround() +} + +// GenerateSignatureSubround generates the instance of subround Signature and added it to the chronology subrounds list +func (fct *factory) 
GenerateSignatureSubround() error {
+	return fct.generateSignatureSubround()
+}
+
+// GenerateEndRoundSubround generates the instance of subround EndRound and added it to the chronology subrounds list
+func (fct *factory) GenerateEndRoundSubround() error {
+	return fct.generateEndRoundSubround()
+}
+
+// AppStatusHandler gets the app status handler object
+func (fct *factory) AppStatusHandler() core.AppStatusHandler {
+	return fct.appStatusHandler
+}
+
+// Outport gets the outport object
+func (fct *factory) Outport() outport.OutportHandler {
+	return fct.outportHandler
+}
+
+// subroundStartRound
+
+// SubroundStartRound defines an alias for the subroundStartRound structure
+type SubroundStartRound = *subroundStartRound
+
+// DoStartRoundJob method does the job of the subround StartRound
+func (sr *subroundStartRound) DoStartRoundJob() bool {
+	return sr.doStartRoundJob(context.Background())
+}
+
+// DoStartRoundConsensusCheck method checks if the consensus is achieved in the subround StartRound
+func (sr *subroundStartRound) DoStartRoundConsensusCheck() bool {
+	return sr.doStartRoundConsensusCheck()
+}
+
+// GenerateNextConsensusGroup generates the next consensus group based on current (random seed, shard id and round)
+func (sr *subroundStartRound) GenerateNextConsensusGroup(roundIndex int64) error {
+	return sr.generateNextConsensusGroup(roundIndex)
+}
+
+// InitCurrentRound inits all the stuff needed in the current round
+func (sr *subroundStartRound) InitCurrentRound() bool {
+	return sr.initCurrentRound()
+}
+
+// GetSentSignatureTracker returns the subroundStartRound's SentSignaturesTracker instance
+func (sr *subroundStartRound) GetSentSignatureTracker() spos.SentSignaturesTracker {
+	return sr.sentSignatureTracker
+}
+
+// subroundBlock
+
+// SubroundBlock defines an alias for the subroundBlock structure
+type SubroundBlock = *subroundBlock
+
+// BlockChain gets the ChainHandler stored in the ConsensusCore
+func (sr *subroundBlock) BlockChain() data.ChainHandler {
+	return sr.Blockchain()
+}
+
+// DoBlockJob method does the job of the subround Block
+func (sr *subroundBlock) DoBlockJob() bool {
+	return sr.doBlockJob(context.Background())
+}
+
+// ProcessReceivedBlock method processes the received proposed block in the subround Block
+func (sr *subroundBlock) ProcessReceivedBlock(cnsDta *consensus.Message) bool {
+	return sr.processReceivedBlock(context.Background(), cnsDta.RoundIndex, cnsDta.PubKey)
+}
+
+// DoBlockConsensusCheck method checks if the consensus in the subround Block is achieved
+func (sr *subroundBlock) DoBlockConsensusCheck() bool {
+	return sr.doBlockConsensusCheck()
+}
+
+// IsBlockReceived method checks if the block was received from the leader in the current round
+func (sr *subroundBlock) IsBlockReceived(threshold int) bool {
+	return sr.isBlockReceived(threshold)
+}
+
+// CreateHeader method creates the proposed block header in the subround Block
+func (sr *subroundBlock) CreateHeader() (data.HeaderHandler, error) {
+	return sr.createHeader()
+}
+
+// CreateBlock method creates the proposed block body and final header in the subround Block
+func (sr *subroundBlock) CreateBlock(hdr data.HeaderHandler) (data.HeaderHandler, data.BodyHandler, error) {
+	return sr.createBlock(hdr)
+}
+
+// SendBlockBody method sends the proposed block body in the subround Block
+func (sr *subroundBlock) SendBlockBody(body data.BodyHandler, marshalizedBody []byte) bool {
+	return sr.sendBlockBody(body, marshalizedBody)
+}
+
+// SendBlockHeader method sends the proposed block header in the subround Block
+func (sr *subroundBlock) SendBlockHeader(header data.HeaderHandler, marshalizedHeader []byte) bool {
+	return sr.sendBlockHeader(header, marshalizedHeader)
+}
+
+// ComputeSubroundProcessingMetric computes processing metric related to the subround Block
+func (sr *subroundBlock) ComputeSubroundProcessingMetric(startTime time.Time, metric string) {
+	sr.computeSubroundProcessingMetric(startTime, metric)
+}
+
+// ReceivedBlockBody method is called when a block body is received through the block body channel
+func (sr *subroundBlock) ReceivedBlockBody(cnsDta *consensus.Message) bool {
+	return sr.receivedBlockBody(context.Background(), cnsDta)
+}
+
+// ReceivedBlockHeader method is called when a block header is received through the block header channel
+func (sr *subroundBlock) ReceivedBlockHeader(header data.HeaderHandler) {
+	sr.receivedBlockHeader(header)
+}
+
+// GetLeaderForHeader returns the leader based on header info
+func (sr *subroundBlock) GetLeaderForHeader(headerHandler data.HeaderHandler) ([]byte, error) {
+	return sr.getLeaderForHeader(headerHandler)
+}
+
+// subroundSignature
+
+// SubroundSignature defines an alias to the subroundSignature structure
+type SubroundSignature = *subroundSignature
+
+// DoSignatureJob method does the job of the subround Signature
+func (sr *subroundSignature) DoSignatureJob() bool {
+	return sr.doSignatureJob(context.Background())
+}
+
+// DoSignatureConsensusCheck method checks if the consensus in the subround Signature is achieved
+func (sr *subroundSignature) DoSignatureConsensusCheck() bool {
+	return sr.doSignatureConsensusCheck()
+}
+
+// subroundEndRound
+
+// SubroundEndRound defines a type for the subroundEndRound structure
+type SubroundEndRound = *subroundEndRound
+
+// DoEndRoundJob method does the job of the subround EndRound
+func (sr *subroundEndRound) DoEndRoundJob() bool {
+	return sr.doEndRoundJob(context.Background())
+}
+
+// DoEndRoundConsensusCheck method checks if the consensus is achieved
+func (sr *subroundEndRound) DoEndRoundConsensusCheck() bool {
+	return sr.doEndRoundConsensusCheck()
+}
+
+// CheckSignaturesValidity method checks the signature validity for the nodes included in bitmap
+func (sr *subroundEndRound) CheckSignaturesValidity(bitmap []byte) error {
+	return sr.checkSignaturesValidity(bitmap)
+}
+
+// DoEndRoundJobByNode calls the unexported doEndRoundJobByNode function
+func (sr *subroundEndRound) DoEndRoundJobByNode() bool {
+	return sr.doEndRoundJobByNode()
+}
+
+// CreateAndBroadcastProof calls the unexported createAndBroadcastProof function
+func (sr *subroundEndRound) CreateAndBroadcastProof(signature []byte, bitmap []byte) {
+	_ = sr.createAndBroadcastProof(signature, bitmap, "sender")
+}
+
+// ReceivedProof calls the unexported receivedProof function
+func (sr *subroundEndRound) ReceivedProof(proof consensus.ProofHandler) {
+	sr.receivedProof(proof)
+}
+
+// IsOutOfTime calls the unexported isOutOfTime function
+func (sr *subroundEndRound) IsOutOfTime() bool {
+	return sr.isOutOfTime()
+}
+
+// VerifyNodesOnAggSigFail calls the unexported verifyNodesOnAggSigFail function
+func (sr *subroundEndRound) VerifyNodesOnAggSigFail(ctx context.Context) ([]string, error) {
+	return sr.verifyNodesOnAggSigFail(ctx)
+}
+
+// ComputeAggSigOnValidNodes calls the unexported computeAggSigOnValidNodes function
+func (sr *subroundEndRound) ComputeAggSigOnValidNodes() ([]byte, []byte, error) {
+	return sr.computeAggSigOnValidNodes()
+}
+
+// ReceivedInvalidSignersInfo calls the unexported 
receivedInvalidSignersInfo function +func (sr *subroundEndRound) ReceivedInvalidSignersInfo(cnsDta *consensus.Message) bool { + return sr.receivedInvalidSignersInfo(context.Background(), cnsDta) +} + +// VerifyInvalidSigners calls the unexported verifyInvalidSigners function +func (sr *subroundEndRound) VerifyInvalidSigners(invalidSigners []byte) ([]string, error) { + return sr.verifyInvalidSigners(invalidSigners) +} + +// GetMinConsensusGroupIndexOfManagedKeys calls the unexported getMinConsensusGroupIndexOfManagedKeys function +func (sr *subroundEndRound) GetMinConsensusGroupIndexOfManagedKeys() int { + return sr.getMinConsensusGroupIndexOfManagedKeys() +} + +// CreateAndBroadcastInvalidSigners calls the unexported createAndBroadcastInvalidSigners function +func (sr *subroundEndRound) CreateAndBroadcastInvalidSigners(invalidSigners []byte) { + sr.createAndBroadcastInvalidSigners(invalidSigners, nil, "sender") +} + +// GetFullMessagesForInvalidSigners calls the unexported getFullMessagesForInvalidSigners function +func (sr *subroundEndRound) GetFullMessagesForInvalidSigners(invalidPubKeys []string) ([]byte, error) { + return sr.getFullMessagesForInvalidSigners(invalidPubKeys) +} + +// GetSentSignatureTracker returns the subroundEndRound's SentSignaturesTracker instance +func (sr *subroundEndRound) GetSentSignatureTracker() spos.SentSignaturesTracker { + return sr.sentSignatureTracker +} + +// ChangeEpoch calls the unexported changeEpoch function +func (sr *subroundStartRound) ChangeEpoch(epoch uint32) { + sr.changeEpoch(epoch) +} + +// IndexRoundIfNeeded calls the unexported indexRoundIfNeeded function +func (sr *subroundStartRound) IndexRoundIfNeeded(pubKeys []string) { + sr.indexRoundIfNeeded(pubKeys) +} + +// SendSignatureForManagedKey calls the unexported sendSignatureForManagedKey function +func (sr *subroundSignature) SendSignatureForManagedKey(idx int, pk string) bool { + return sr.sendSignatureForManagedKey(idx, pk) +} + +// DoSignatureJobForManagedKeys calls the unexported doSignatureJobForManagedKeys function +func (sr *subroundSignature) DoSignatureJobForManagedKeys(ctx context.Context) bool { + return sr.doSignatureJobForManagedKeys(ctx) +} + +// ReceivedSignature method is called when a signature is received through the signature channel +func (sr *subroundEndRound) ReceivedSignature(cnsDta *consensus.Message) bool { + return sr.receivedSignature(context.Background(), cnsDta) +} + +// WaitForProof - +func (sr *subroundEndRound) WaitForProof() bool { + return sr.waitForProof() +} + +// GetEquivalentProofSender - +func (sr *subroundEndRound) GetEquivalentProofSender() string { + return sr.getEquivalentProofSender() +} + +// SendProof - +func (sr *subroundEndRound) SendProof() (bool, error) { + return sr.sendProof() +} diff --git a/consensus/spos/bls/v2/subroundBlock.go b/consensus/spos/bls/v2/subroundBlock.go new file mode 100644 index 00000000000..8dfe90fb678 --- /dev/null +++ b/consensus/spos/bls/v2/subroundBlock.go @@ -0,0 +1,728 @@ +package v2 + +import ( + "bytes" + "context" + "sync" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" +) + +// 
maxAllowedSizeInBytes defines how many bytes are allowed as payload in a message +const maxAllowedSizeInBytes = uint32(core.MegabyteSize * 95 / 100) + +// subroundBlock defines the data needed by the subround Block +type subroundBlock struct { + *spos.Subround + + processingThresholdPercentage int + worker spos.WorkerHandler + mutBlockProcessing sync.Mutex +} + +// NewSubroundBlock creates a subroundBlock object +func NewSubroundBlock( + baseSubround *spos.Subround, + processingThresholdPercentage int, + worker spos.WorkerHandler, +) (*subroundBlock, error) { + err := checkNewSubroundBlockParams(baseSubround) + if err != nil { + return nil, err + } + + if check.IfNil(worker) { + return nil, spos.ErrNilWorker + } + + srBlock := subroundBlock{ + Subround: baseSubround, + processingThresholdPercentage: processingThresholdPercentage, + worker: worker, + } + + srBlock.Job = srBlock.doBlockJob + srBlock.Check = srBlock.doBlockConsensusCheck + srBlock.Extend = srBlock.worker.Extend + + return &srBlock, nil +} + +func checkNewSubroundBlockParams( + baseSubround *spos.Subround, +) error { + if baseSubround == nil { + return spos.ErrNilSubround + } + + if check.IfNil(baseSubround.ConsensusStateHandler) { + return spos.ErrNilConsensusState + } + + err := spos.ValidateConsensusCore(baseSubround.ConsensusCoreHandler) + + return err +} + +// doBlockJob method does the job of the subround Block +func (sr *subroundBlock) doBlockJob(ctx context.Context) bool { + if !sr.IsSelfLeader() { // is NOT self leader in this round? + return false + } + + if sr.RoundHandler().Index() <= sr.getRoundInLastCommittedBlock() { + return false + } + + if sr.IsLeaderJobDone(sr.Current()) { + return false + } + + if sr.IsSubroundFinished(sr.Current()) { + return false + } + + metricStatTime := time.Now() + defer sr.computeSubroundProcessingMetric(metricStatTime, common.MetricCreatedProposedBlock) + + header, err := sr.createHeader() + if err != nil { + printLogMessage(ctx, "doBlockJob.createHeader", err) + return false + } + + header, body, err := sr.createBlock(header) + if err != nil { + printLogMessage(ctx, "doBlockJob.createBlock", err) + return false + } + + // block proof verification should be done over the header that contains the leader signature + leaderSignature, err := sr.signBlockHeader(header) + if err != nil { + printLogMessage(ctx, "doBlockJob.signBlockHeader", err) + return false + } + + err = header.SetLeaderSignature(leaderSignature) + if err != nil { + printLogMessage(ctx, "doBlockJob.SetLeaderSignature", err) + return false + } + + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("doBlockJob.GetLeader", "error", errGetLeader) + return false + } + + sentWithSuccess := sr.sendBlock(header, body, leader) + if !sentWithSuccess { + return false + } + + err = sr.SetJobDone(leader, sr.Current(), true) + if err != nil { + log.Debug("doBlockJob.SetSelfJobDone", "error", err.Error()) + return false + } + + // placeholder for subroundBlock.doBlockJob script + + sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(header, body, sr.GetRoundTimeStamp()) + + return true +} + +func (sr *subroundBlock) signBlockHeader(header data.HeaderHandler) ([]byte, error) { + headerClone := header.ShallowClone() + err := headerClone.SetLeaderSignature(nil) + if err != nil { + return nil, err + } + + marshalledHdr, err := sr.Marshalizer().Marshal(headerClone) + if err != nil { + return nil, err + } + + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + return nil, 
errGetLeader + } + + return sr.SigningHandler().CreateSignatureForPublicKey(marshalledHdr, []byte(leader)) +} + +func printLogMessage(ctx context.Context, baseMessage string, err error) { + if common.IsContextDone(ctx) { + log.Debug(baseMessage + " context is closing") + return + } + + log.Debug(baseMessage, "error", err.Error()) +} + +func (sr *subroundBlock) sendBlock(header data.HeaderHandler, body data.BodyHandler, _ string) bool { + marshalledBody, err := sr.Marshalizer().Marshal(body) + if err != nil { + log.Debug("sendBlock.Marshal: body", "error", err.Error()) + return false + } + + marshalledHeader, err := sr.Marshalizer().Marshal(header) + if err != nil { + log.Debug("sendBlock.Marshal: header", "error", err.Error()) + return false + } + + sr.logBlockSize(marshalledBody, marshalledHeader) + if !sr.sendBlockBody(body, marshalledBody) || !sr.sendBlockHeader(header, marshalledHeader) { + return false + } + + return true +} + +func (sr *subroundBlock) logBlockSize(marshalledBody []byte, marshalledHeader []byte) { + bodyAndHeaderSize := uint32(len(marshalledBody) + len(marshalledHeader)) + log.Debug("logBlockSize", + "body size", len(marshalledBody), + "header size", len(marshalledHeader), + "body and header size", bodyAndHeaderSize, + "max allowed size in bytes", maxAllowedSizeInBytes) +} + +func (sr *subroundBlock) createBlock(header data.HeaderHandler) (data.HeaderHandler, data.BodyHandler, error) { + startTime := sr.GetRoundTimeStamp() + maxTime := time.Duration(sr.EndTime()) + haveTimeInCurrentSubround := func() bool { + return sr.RoundHandler().RemainingTime(startTime, maxTime) > 0 + } + + finalHeader, blockBody, err := sr.BlockProcessor().CreateBlock( + header, + haveTimeInCurrentSubround, + ) + if err != nil { + return nil, nil, err + } + + return finalHeader, blockBody, nil +} + +// sendBlockBody method sends the proposed block body in the subround Block +func (sr *subroundBlock) sendBlockBody( + bodyHandler data.BodyHandler, + marshalizedBody []byte, +) bool { + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("sendBlockBody.GetLeader", "error", errGetLeader) + return false + } + + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + marshalizedBody, + nil, + []byte(leader), + nil, + int(bls.MtBlockBody), + sr.RoundHandler().Index(), + sr.ChainID(), + nil, + nil, + nil, + sr.GetAssociatedPid([]byte(leader)), + nil, + ) + + err := sr.BroadcastMessenger().BroadcastConsensusMessage(cnsMsg) + if err != nil { + log.Debug("sendBlockBody.BroadcastConsensusMessage", "error", err.Error()) + return false + } + + log.Debug("step 1: block body has been sent") + + sr.SetBody(bodyHandler) + + return true +} + +// sendBlockHeader method sends the proposed block header in the subround Block +func (sr *subroundBlock) sendBlockHeader( + headerHandler data.HeaderHandler, + marshalledHeader []byte, +) bool { + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + log.Debug("sendBlockHeader.GetLeader", "error", errGetLeader) + return false + } + + err := sr.BroadcastMessenger().BroadcastHeader(headerHandler, []byte(leader)) + if err != nil { + log.Warn("sendBlockHeader.BroadcastHeader", "error", err.Error()) + return false + } + + headerHash := sr.Hasher().Compute(string(marshalledHeader)) + + log.Debug("step 1: block header has been sent", + "nonce", headerHandler.GetNonce(), + "hash", headerHash) + + sr.SetData(headerHash) + sr.SetHeader(headerHandler) + + return true +} + +func (sr *subroundBlock) getPrevHeaderAndHash() (data.HeaderHandler, 
[]byte) { + prevHeader := sr.Blockchain().GetCurrentBlockHeader() + prevHeaderHash := sr.Blockchain().GetCurrentBlockHeaderHash() + if check.IfNil(prevHeader) { + prevHeader = sr.Blockchain().GetGenesisHeader() + prevHeaderHash = sr.Blockchain().GetGenesisHeaderHash() + } + + return prevHeader, prevHeaderHash +} + +func (sr *subroundBlock) createHeader() (data.HeaderHandler, error) { + prevHeader, prevHash := sr.getPrevHeaderAndHash() + nonce := prevHeader.GetNonce() + 1 + prevRandSeed := prevHeader.GetRandSeed() + + round := uint64(sr.RoundHandler().Index()) + hdr, err := sr.BlockProcessor().CreateNewHeader(round, nonce) + if err != nil { + return nil, err + } + + err = hdr.SetPrevHash(prevHash) + if err != nil { + return nil, err + } + + leader, errGetLeader := sr.GetLeader() + if errGetLeader != nil { + return nil, errGetLeader + } + + randSeed, err := sr.SigningHandler().CreateSignatureForPublicKey(prevRandSeed, []byte(leader)) + if err != nil { + return nil, err + } + + err = hdr.SetShardID(sr.ShardCoordinator().SelfId()) + if err != nil { + return nil, err + } + + err = hdr.SetTimeStamp(uint64(sr.RoundHandler().TimeStamp().Unix())) + if err != nil { + return nil, err + } + + err = hdr.SetPrevRandSeed(prevRandSeed) + if err != nil { + return nil, err + } + + err = hdr.SetRandSeed(randSeed) + if err != nil { + return nil, err + } + + err = hdr.SetChainID(sr.ChainID()) + if err != nil { + return nil, err + } + + return hdr, nil +} + +// receivedBlockBody method is called when a block body is received through the block body channel +func (sr *subroundBlock) receivedBlockBody(ctx context.Context, cnsDta *consensus.Message) bool { + node := string(cnsDta.PubKey) + + if !sr.IsNodeLeaderInCurrentRound(node) { // is NOT this node leader in current round? 
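+		// descriptive note (added comment): a block body received from a node that is not the
+		// leader of the current round is discarded and the sender's peer honesty score is decreased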
+ sr.PeerHonestyHandler().ChangeScore( + node, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.LeaderPeerHonestyDecreaseFactor, + ) + + return false + } + + if sr.IsBlockBodyAlreadyReceived() { + return false + } + + if !sr.CanProcessReceivedMessage(cnsDta, sr.RoundHandler().Index(), sr.Current()) { + return false + } + + sr.SetBody(sr.BlockProcessor().DecodeBlockBody(cnsDta.Body)) + + if check.IfNil(sr.GetBody()) { + return false + } + + log.Debug("step 1: block body has been received") + + blockProcessedWithSuccess := sr.processReceivedBlock(ctx, cnsDta.RoundIndex, cnsDta.PubKey) + + sr.PeerHonestyHandler().ChangeScore( + node, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.LeaderPeerHonestyIncreaseFactor, + ) + + return blockProcessedWithSuccess +} + +func (sr *subroundBlock) isHeaderForCurrentConsensus(header data.HeaderHandler) bool { + if check.IfNil(header) { + return false + } + if header.GetShardID() != sr.ShardCoordinator().SelfId() { + return false + } + if header.GetRound() != uint64(sr.RoundHandler().Index()) { + return false + } + + prevHeader, prevHash := sr.getPrevHeaderAndHash() + if check.IfNil(prevHeader) { + return false + } + if !bytes.Equal(header.GetPrevHash(), prevHash) { + return false + } + if header.GetNonce() != prevHeader.GetNonce()+1 { + return false + } + prevRandSeed := prevHeader.GetRandSeed() + + return bytes.Equal(header.GetPrevRandSeed(), prevRandSeed) +} + +func (sr *subroundBlock) getLeaderForHeader(headerHandler data.HeaderHandler) ([]byte, error) { + nc := sr.NodesCoordinator() + + prevBlockEpoch := uint32(0) + if sr.Blockchain().GetCurrentBlockHeader() != nil { + prevBlockEpoch = sr.Blockchain().GetCurrentBlockHeader().GetEpoch() + } + // TODO: remove this if first block in new epoch will be validated by epoch validators + // first block in epoch is validated by previous epoch validators + selectionEpoch := headerHandler.GetEpoch() + if selectionEpoch != prevBlockEpoch { + selectionEpoch = prevBlockEpoch + } + leader, _, err := nc.ComputeConsensusGroup( + headerHandler.GetPrevRandSeed(), + headerHandler.GetRound(), + headerHandler.GetShardID(), + selectionEpoch, + ) + if err != nil { + return nil, err + } + + return leader.PubKey(), err +} + +func (sr *subroundBlock) receivedBlockHeader(headerHandler data.HeaderHandler) { + if check.IfNil(headerHandler) { + return + } + + log.Debug("subroundBlock.receivedBlockHeader", "nonce", headerHandler.GetNonce(), "round", headerHandler.GetRound()) + if headerHandler.CheckFieldsForNil() != nil { + return + } + + isHeaderForCurrentConsensus := sr.isHeaderForCurrentConsensus(headerHandler) + if !isHeaderForCurrentConsensus { + log.Debug("subroundBlock.receivedBlockHeader - header is not for current consensus") + return + } + + isLeader := sr.IsSelfLeader() + if sr.ConsensusGroup() == nil || isLeader { + log.Debug("subroundBlock.receivedBlockHeader - consensus group is nil or is leader") + return + } + + if sr.IsConsensusDataSet() { + log.Debug("subroundBlock.receivedBlockHeader - consensus data is set") + return + } + + headerLeader, err := sr.getLeaderForHeader(headerHandler) + if err != nil { + log.Debug("subroundBlock.receivedBlockHeader - error getting leader for header", err.Error()) + return + } + + if !sr.IsNodeLeaderInCurrentRound(string(headerLeader)) { + sr.PeerHonestyHandler().ChangeScore( + string(headerLeader), + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.LeaderPeerHonestyDecreaseFactor, + ) + + log.Debug("subroundBlock.receivedBlockHeader - leader is not the leader 
in current round") + return + } + + if sr.IsHeaderAlreadyReceived() { + log.Debug("subroundBlock.receivedBlockHeader - header is already received") + return + } + + if !sr.CanProcessReceivedHeader(string(headerLeader)) { + log.Debug("subroundBlock.receivedBlockHeader - can not process received header") + return + } + + headerHash, err := core.CalculateHash(sr.Marshalizer(), sr.Hasher(), headerHandler) + if err != nil { + log.Debug("subroundBlock.receivedBlockHeader", "error", err.Error()) + return + } + + sr.SetData(headerHash) + sr.SetHeader(headerHandler) + + log.Debug("step 1: block header has been received", + "nonce", sr.GetHeader().GetNonce(), + "hash", sr.GetData()) + + sr.AddReceivedHeader(headerHandler) + + ctx, cancel := context.WithTimeout(context.Background(), sr.RoundHandler().TimeDuration()) + defer cancel() + + _ = sr.processReceivedBlock(ctx, int64(headerHandler.GetRound()), []byte(sr.Leader())) + sr.PeerHonestyHandler().ChangeScore( + sr.Leader(), + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.LeaderPeerHonestyIncreaseFactor, + ) +} + +// CanProcessReceivedHeader method returns true if the received header can be processed and false otherwise +func (sr *subroundBlock) CanProcessReceivedHeader(headerLeader string) bool { + return sr.shouldProcessBlock(headerLeader) +} + +func (sr *subroundBlock) shouldProcessBlock(headerLeader string) bool { + if sr.IsNodeSelf(headerLeader) { + return false + } + if sr.IsJobDone(headerLeader, sr.Current()) { + return false + } + + if sr.IsSubroundFinished(sr.Current()) { + return false + } + + return true +} + +func (sr *subroundBlock) processReceivedBlock( + ctx context.Context, + round int64, + senderPK []byte, +) bool { + if check.IfNil(sr.GetBody()) { + return false + } + if check.IfNil(sr.GetHeader()) { + return false + } + + sw := core.NewStopWatch() + sw.Start("processReceivedBlock") + + sr.mutBlockProcessing.Lock() + defer sr.mutBlockProcessing.Unlock() + + defer func() { + sw.Stop("processReceivedBlock") + log.Info("time measurements of processReceivedBlock", sw.GetMeasurements()...) 
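+		// descriptive note (added comment): the processing flag is reset below so it is cleared on every exit path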
+ + sr.SetProcessingBlock(false) + }() + + sr.SetProcessingBlock(true) + + shouldNotProcessBlock := sr.GetExtendedCalled() || round < sr.RoundHandler().Index() + if shouldNotProcessBlock { + log.Debug("canceled round, extended has been called or round index has been changed", + "round", sr.RoundHandler().Index(), + "subround", sr.Name(), + "cnsDta round", round, + "extended called", sr.GetExtendedCalled(), + ) + return false + } + + // check again under critical section to avoid double execution + if !sr.shouldProcessBlock(string(senderPK)) { + return false + } + + sw.Start("processBlock") + ok := sr.processBlock(ctx, round, senderPK) + sw.Stop("processBlock") + + return ok +} + +func (sr *subroundBlock) processBlock( + ctx context.Context, + roundIndex int64, + pubkey []byte, +) bool { + startTime := sr.GetRoundTimeStamp() + maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 + remainingTimeInCurrentRound := func() time.Duration { + return sr.RoundHandler().RemainingTime(startTime, maxTime) + } + + metricStatTime := time.Now() + defer sr.computeSubroundProcessingMetric(metricStatTime, common.MetricProcessedProposedBlock) + + err := sr.BlockProcessor().ProcessBlock( + sr.GetHeader(), + sr.GetBody(), + remainingTimeInCurrentRound, + ) + + if roundIndex < sr.RoundHandler().Index() { + log.Debug("canceled round, round index has been changed", + "round", sr.RoundHandler().Index(), + "subround", sr.Name(), + "cnsDta round", roundIndex, + ) + return false + } + + if err != nil { + sr.printCancelRoundLogMessage(ctx, err) + sr.SetRoundCanceled(true) + + return false + } + + node := string(pubkey) + err = sr.SetJobDone(node, sr.Current(), true) + if err != nil { + sr.printCancelRoundLogMessage(ctx, err) + return false + } + + sr.ConsensusCoreHandler.ScheduledProcessor().StartScheduledProcessing(sr.GetHeader(), sr.GetBody(), sr.GetRoundTimeStamp()) + + return true +} + +func (sr *subroundBlock) printCancelRoundLogMessage(ctx context.Context, err error) { + if common.IsContextDone(ctx) { + log.Debug("canceled round as the context is closing") + return + } + + log.Debug("canceled round", + "round", sr.RoundHandler().Index(), + "subround", sr.Name(), + "error", err.Error()) +} + +func (sr *subroundBlock) computeSubroundProcessingMetric(startTime time.Time, metric string) { + subRoundDuration := sr.EndTime() - sr.StartTime() + if subRoundDuration == 0 { + // can not do division by 0 + return + } + + percent := uint64(time.Since(startTime)) * 100 / uint64(subRoundDuration) + sr.AppStatusHandler().SetUInt64Value(metric, percent) +} + +// doBlockConsensusCheck method checks if the consensus in the subround Block is achieved +func (sr *subroundBlock) doBlockConsensusCheck() bool { + if sr.GetRoundCanceled() { + return false + } + + if sr.IsSubroundFinished(sr.Current()) { + return true + } + + threshold := sr.Threshold(sr.Current()) + if sr.isBlockReceived(threshold) { + log.Debug("step 1: subround has been finished", + "subround", sr.Name()) + sr.SetStatus(sr.Current(), spos.SsFinished) + return true + } + + return false +} + +// isBlockReceived method checks if the block was received from the leader in the current round +func (sr *subroundBlock) isBlockReceived(threshold int) bool { + n := 0 + + for i := 0; i < len(sr.ConsensusGroup()); i++ { + node := sr.ConsensusGroup()[i] + isJobDone, err := sr.JobDone(node, sr.Current()) + if err != nil { + log.Debug("isBlockReceived.JobDone", + "node", node, + "subround", sr.Name(), + "error", err.Error()) + continue + 
} + + if isJobDone { + n++ + } + } + + return n >= threshold +} + +func (sr *subroundBlock) getRoundInLastCommittedBlock() int64 { + roundInLastCommittedBlock := int64(0) + currentHeader := sr.Blockchain().GetCurrentBlockHeader() + if !check.IfNil(currentHeader) { + roundInLastCommittedBlock = int64(currentHeader.GetRound()) + } + + return roundInLastCommittedBlock +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sr *subroundBlock) IsInterfaceNil() bool { + return sr == nil +} diff --git a/consensus/spos/bls/v2/subroundBlock_test.go b/consensus/spos/bls/v2/subroundBlock_test.go new file mode 100644 index 00000000000..12a82fd6529 --- /dev/null +++ b/consensus/spos/bls/v2/subroundBlock_test.go @@ -0,0 +1,1359 @@ +package v2_test + +import ( + "errors" + "fmt" + "math/big" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/testscommon" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +var expectedErr = errors.New("expected error") + +func defaultSubroundForSRBlock(consensusState *spos.ConsensusState, ch chan bool, + container *spos.ConsensusCore, appStatusHandler core.AppStatusHandler) (*spos.Subround, error) { + return spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) +} + +func createDefaultHeader() *block.Header { + return &block.Header{ + Nonce: 1, + PrevHash: []byte("prev hash"), + PrevRandSeed: []byte("prev rand seed"), + RandSeed: []byte("rand seed"), + RootHash: []byte("roothash"), + TxCount: 0, + ChainID: []byte("chain ID"), + SoftwareVersion: []byte("software version"), + AccumulatedFees: big.NewInt(0), + DeveloperFees: big.NewInt(0), + } +} + +func defaultSubroundBlockFromSubround(sr *spos.Subround) (v2.SubroundBlock, error) { + srBlock, err := v2.NewSubroundBlock( + sr, + v2.ProcessingThresholdPercent, + &consensusMocks.SposWorkerMock{}, + ) + + return srBlock, err +} + +func defaultSubroundBlockWithoutErrorFromSubround(sr *spos.Subround) v2.SubroundBlock { + srBlock, _ := v2.NewSubroundBlock( + sr, + v2.ProcessingThresholdPercent, + 
&consensusMocks.SposWorkerMock{}, + ) + + return srBlock +} + +func initSubroundBlock( + blockChain data.ChainHandler, + container *spos.ConsensusCore, + appStatusHandler core.AppStatusHandler, +) v2.SubroundBlock { + if blockChain == nil { + blockChain = &testscommon.ChainHandlerStub{ + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &block.Header{} + }, + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{ + Nonce: uint64(0), + Signature: []byte("genesis signature"), + RandSeed: []byte{0}, + } + }, + GetGenesisHeaderHashCalled: func() []byte { + return []byte("genesis header hash") + }, + } + } + + consensusState := initializers.InitConsensusStateWithNodesCoordinator(container.NodesCoordinator()) + ch := make(chan bool, 1) + + container.SetBlockchain(blockChain) + + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, appStatusHandler) + srBlock, _ := defaultSubroundBlockFromSubround(sr) + return srBlock +} + +func createConsensusContainers() []*spos.ConsensusCore { + consensusContainers := make([]*spos.ConsensusCore, 0) + container := consensusMocks.InitConsensusCore() + consensusContainers = append(consensusContainers, container) + container = consensusMocks.InitConsensusCoreHeaderV2() + consensusContainers = append(consensusContainers, container) + return consensusContainers +} + +func initSubroundBlockWithBlockProcessor( + bp *testscommon.BlockProcessorStub, + container *spos.ConsensusCore, +) v2.SubroundBlock { + blockChain := &testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{ + Nonce: uint64(0), + Signature: []byte("genesis signature"), + } + }, + GetGenesisHeaderHashCalled: func() []byte { + return []byte("genesis header hash") + }, + } + blockProcessorMock := bp + + container.SetBlockchain(blockChain) + container.SetBlockProcessor(blockProcessorMock) + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + srBlock, _ := defaultSubroundBlockFromSubround(sr) + return srBlock +} + +func TestSubroundBlock_NewSubroundBlockNilSubroundShouldFail(t *testing.T) { + t.Parallel() + + srBlock, err := v2.NewSubroundBlock( + nil, + v2.ProcessingThresholdPercent, + &consensusMocks.SposWorkerMock{}, + ) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilSubround, err) +} + +func TestSubroundBlock_NewSubroundBlockNilBlockchainShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetBlockchain(nil) + + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilBlockChain, err) +} + +func TestSubroundBlock_NewSubroundBlockNilBlockProcessorShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetBlockProcessor(nil) + + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilBlockProcessor, err) +} + +func TestSubroundBlock_NewSubroundBlockNilConsensusStateShouldFail(t 
*testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + sr.ConsensusStateHandler = nil + + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilConsensusState, err) +} + +func TestSubroundBlock_NewSubroundBlockNilHasherShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetHasher(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilHasher, err) +} + +func TestSubroundBlock_NewSubroundBlockNilMarshalizerShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetMarshalizer(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilMarshalizer, err) +} + +func TestSubroundBlock_NewSubroundBlockNilMultiSignerContainerShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetMultiSignerContainer(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) +} + +func TestSubroundBlock_NewSubroundBlockNilRoundHandlerShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilRoundHandler, err) +} + +func TestSubroundBlock_NewSubroundBlockNilShardCoordinatorShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetShardCoordinator(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilShardCoordinator, err) +} + +func TestSubroundBlock_NewSubroundBlockNilSyncTimerShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetSyncTimer(nil) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestSubroundBlock_NewSubroundBlockNilWorkerShouldFail(t 
*testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + + srBlock, err := v2.NewSubroundBlock( + sr, + v2.ProcessingThresholdPercent, + nil, + ) + assert.Nil(t, srBlock) + assert.Equal(t, spos.ErrNilWorker, err) +} + +func TestSubroundBlock_NewSubroundBlockShouldWork(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + srBlock, err := defaultSubroundBlockFromSubround(sr) + assert.NotNil(t, srBlock) + assert.Nil(t, err) +} + +func TestSubroundBlock_DoBlockJob(t *testing.T) { + t.Parallel() + + t.Run("not leader should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("round index lower than last committed block should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrBlock, true) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("leader job done should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrBlock, true) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("subround finished should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + _ = sr.SetJobDone(sr.SelfPubKey(), bls.SrBlock, false) + sr.SetStatus(bls.SrBlock, spos.SsFinished) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("create header error should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + sr.SetStatus(bls.SrBlock, spos.SsNotFinished) + bpm := &testscommon.BlockProcessorStub{} + + bpm.CreateNewHeaderCalled = func(round uint64, nonce uint64) (data.HeaderHandler, error) { + return nil, expectedErr + } + container.SetBlockProcessor(bpm) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("create block error should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := 
initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + sr.SetStatus(bls.SrBlock, spos.SsNotFinished) + bpm := &testscommon.BlockProcessorStub{} + bpm.CreateBlockCalled = func(header data.HeaderHandler, remainingTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { + return header, nil, expectedErr + } + bpm.CreateNewHeaderCalled = func(round uint64, nonce uint64) (data.HeaderHandler, error) { + return &block.Header{}, nil + } + container.SetBlockProcessor(bpm) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("sign block header failure should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + container.SetBlockProcessor(bpm) + cnt := uint32(0) + sh := &consensusMocks.SigningHandlerStub{ + CreateSignatureForPublicKeyCalled: func(message []byte, publicKeyBytes []byte) ([]byte, error) { + cnt++ + if cnt > 1 { // first call is from create header + return nil, expectedErr + } + + return []byte("sig"), nil + }, + } + container.SetSigningHandler(sh) + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("set leader signature failure should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + bpm.CreateNewHeaderCalled = func(round uint64, nonce uint64) (data.HeaderHandler, error) { + return &testscommon.HeaderHandlerStub{ + SetLeaderSignatureCalled: func(signature []byte) error { + if len(signature) > 0 { + return expectedErr + } + return nil + }, + CloneCalled: func() data.HeaderHandler { + return &block.HeaderV2{ + Header: &block.Header{}, + } + }, + }, nil + } + container.SetBlockProcessor(bpm) + sh := &consensusMocks.SigningHandlerStub{ + CreateSignatureForPublicKeyCalled: func(message []byte, publicKeyBytes []byte) ([]byte, error) { + return []byte("sig"), nil + }, + } + container.SetSigningHandler(sh) + + r := sr.DoBlockJob() + assert.False(t, r) + }) + t.Run("send block error should return false", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + container.SetBlockProcessor(bpm) + bm := &consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + return expectedErr + }, + } + container.SetBroadcastMessenger(bm) + r := sr.DoBlockJob() + 
assert.False(t, r) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + providedSignature := []byte("provided signature") + providedBitmap := []byte("provided bitmap") + providedHash := []byte("provided hash") + providedHeadr := &block.HeaderV2{ + Header: &block.Header{ + Signature: []byte("signature"), + PubKeysBitmap: []byte("bitmap"), + }, + } + + container := consensusMocks.InitConsensusCore() + chainHandler := &testscommon.ChainHandlerStub{ + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return providedHeadr + }, + GetCurrentBlockHeaderHashCalled: func() []byte { + return providedHash + }, + } + container.SetBlockchain(chainHandler) + + consensusState := initializers.InitConsensusStateWithNodesCoordinator(container.NodesCoordinator()) + ch := make(chan bool, 1) + + baseSr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + sr, _ := v2.NewSubroundBlock( + baseSr, + v2.ProcessingThresholdPercent, + &consensusMocks.SposWorkerMock{}, + ) + + providedLeaderSignature := []byte("leader signature") + container.SetSigningHandler(&consensusMocks.SigningHandlerStub{ + CreateSignatureForPublicKeyCalled: func(message []byte, publicKeyBytes []byte) ([]byte, error) { + return providedLeaderSignature, nil + }, + VerifySignatureShareCalled: func(index uint16, sig []byte, msg []byte, epoch uint32) error { + assert.Fail(t, "should have not been called for leader") + return nil + }, + }) + container.SetRoundHandler(&testscommon.RoundHandlerMock{ + IndexCalled: func() int64 { + return 1 + }, + }) + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + + sr.SetSelfPubKey(leader) + bpm := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + container.SetBlockProcessor(bpm) + bpm.CreateNewHeaderCalled = func(round uint64, nonce uint64) (data.HeaderHandler, error) { + return &block.HeaderV2{ + Header: &block.Header{ + Round: round, + Nonce: nonce, + }, + }, nil + } + bm := &consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + return nil + }, + } + container.SetBroadcastMessenger(bm) + container.SetRoundHandler(&consensusMocks.RoundHandlerMock{ + RoundIndex: 1, + }) + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + return &block.HeaderProof{ + HeaderHash: headerHash, + AggregatedSignature: providedSignature, + PubKeysBitmap: providedBitmap, + }, nil + }, + }) + + r := sr.DoBlockJob() + assert.True(t, r) + assert.Equal(t, uint64(1), sr.GetHeader().GetNonce()) + }) +} + +func TestSubroundBlock_ReceivedBlock(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + blkBody := &block.Body{} + blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + leader, err := sr.GetLeader() + assert.Nil(t, err) + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + blkBodyStr, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBody), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + sr.SetBody(&block.Body{}) + r := sr.ReceivedBlockBody(cnsMsg) + assert.False(t, r) + + 
sr.SetBody(nil) + cnsMsg.PubKey = []byte(sr.ConsensusGroup()[1]) + r = sr.ReceivedBlockBody(cnsMsg) + assert.False(t, r) + + cnsMsg.PubKey = []byte(sr.ConsensusGroup()[0]) + sr.SetStatus(bls.SrBlock, spos.SsFinished) + r = sr.ReceivedBlockBody(cnsMsg) + assert.False(t, r) + + sr.SetStatus(bls.SrBlock, spos.SsNotFinished) + r = sr.ReceivedBlockBody(cnsMsg) + assert.False(t, r) +} + +func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenBodyAndHeaderAreNotSet(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + leader, _ := sr.GetLeader() + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + nil, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBodyAndHeader), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + assert.False(t, sr.ProcessReceivedBlock(cnsMsg)) +} + +func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockFails(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + blProcMock := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + err := errors.New("error process block") + blProcMock.ProcessBlockCalled = func(data.HeaderHandler, data.BodyHandler, func() time.Duration) error { + return err + } + container.SetBlockProcessor(blProcMock) + hdr := &block.Header{} + blkBody := &block.Body{} + blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + blkBodyStr, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBody), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + sr.SetHeader(hdr) + sr.SetBody(blkBody) + assert.False(t, sr.ProcessReceivedBlock(cnsMsg)) +} + +func TestSubroundBlock_ProcessReceivedBlockShouldReturnFalseWhenProcessBlockReturnsInNextRound(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + hdr := &block.Header{} + blkBody := &block.Body{} + blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + blkBodyStr, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBody), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + sr.SetHeader(hdr) + sr.SetBody(blkBody) + blockProcessorMock := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + blockProcessorMock.ProcessBlockCalled = func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { + return expectedErr + } + container.SetBlockProcessor(blockProcessorMock) + container.SetRoundHandler(&consensusMocks.RoundHandlerMock{RoundIndex: 1}) + assert.False(t, sr.ProcessReceivedBlock(cnsMsg)) +} + +func TestSubroundBlock_ProcessReceivedBlockShouldReturnTrue(t *testing.T) { + t.Parallel() + + consensusContainers := createConsensusContainers() + for _, container := range consensusContainers { + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + hdr, _ := container.BlockProcessor().CreateNewHeader(1, 1) + hdr, blkBody, _ := container.BlockProcessor().CreateBlock(hdr, func() bool { return true }) + + blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + leader, _ := sr.GetLeader() + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + 
blkBodyStr, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBody), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + sr.SetHeader(hdr) + sr.SetBody(blkBody) + assert.True(t, sr.ProcessReceivedBlock(cnsMsg)) + } +} + +func TestSubroundBlock_RemainingTimeShouldReturnNegativeValue(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + roundHandlerMock := initRoundHandlerMock() + container.SetRoundHandler(roundHandlerMock) + + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + remainingTimeInThisRound := func() time.Duration { + roundStartTime := sr.RoundHandler().TimeStamp() + currentTime := sr.SyncTimer().CurrentTime() + elapsedTime := currentTime.Sub(roundStartTime) + remainingTime := sr.RoundHandler().TimeDuration()*85/100 - elapsedTime + + return remainingTime + } + container.SetSyncTimer(&consensusMocks.SyncTimerMock{CurrentTimeCalled: func() time.Time { + return time.Unix(0, 0).Add(roundTimeDuration * 84 / 100) + }}) + ret := remainingTimeInThisRound() + assert.True(t, ret > 0) + + container.SetSyncTimer(&consensusMocks.SyncTimerMock{CurrentTimeCalled: func() time.Time { + return time.Unix(0, 0).Add(roundTimeDuration * 85 / 100) + }}) + ret = remainingTimeInThisRound() + assert.True(t, ret == 0) + + container.SetSyncTimer(&consensusMocks.SyncTimerMock{CurrentTimeCalled: func() time.Time { + return time.Unix(0, 0).Add(roundTimeDuration * 86 / 100) + }}) + ret = remainingTimeInThisRound() + assert.True(t, ret < 0) +} + +func TestSubroundBlock_DoBlockConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr.SetRoundCanceled(true) + assert.False(t, sr.DoBlockConsensusCheck()) +} + +func TestSubroundBlock_DoBlockConsensusCheckShouldReturnTrueWhenSubroundIsFinished(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + sr.SetStatus(bls.SrBlock, spos.SsFinished) + assert.True(t, sr.DoBlockConsensusCheck()) +} + +func TestSubroundBlock_DoBlockConsensusCheckShouldReturnTrueWhenBlockIsReceivedReturnTrue(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + for i := 0; i < sr.Threshold(bls.SrBlock); i++ { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrBlock, true) + } + assert.True(t, sr.DoBlockConsensusCheck()) +} + +func TestSubroundBlock_DoBlockConsensusCheckShouldReturnFalseWhenBlockIsReceivedReturnFalse(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + assert.False(t, sr.DoBlockConsensusCheck()) +} + +func TestSubroundBlock_IsBlockReceived(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + for i := 0; i < len(sr.ConsensusGroup()); i++ { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrBlock, false) + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, false) + } + ok := sr.IsBlockReceived(1) + assert.False(t, ok) + + _ = sr.SetJobDone("A", bls.SrBlock, true) + isJobDone, _ := sr.JobDone("A", bls.SrBlock) + assert.True(t, isJobDone) + + ok = sr.IsBlockReceived(1) + assert.True(t, ok) + + ok = 
sr.IsBlockReceived(2) + assert.False(t, ok) +} + +func TestSubroundBlock_HaveTimeInCurrentSubroundShouldReturnTrue(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + haveTimeInCurrentSubound := func() bool { + roundStartTime := sr.RoundHandler().TimeStamp() + currentTime := sr.SyncTimer().CurrentTime() + elapsedTime := currentTime.Sub(roundStartTime) + remainingTime := sr.EndTime() - int64(elapsedTime) + + return time.Duration(remainingTime) > 0 + } + roundHandlerMock := &consensusMocks.RoundHandlerMock{} + roundHandlerMock.TimeDurationCalled = func() time.Duration { + return 4000 * time.Millisecond + } + roundHandlerMock.TimeStampCalled = func() time.Time { + return time.Unix(0, 0) + } + syncTimerMock := &consensusMocks.SyncTimerMock{} + timeElapsed := sr.EndTime() - 1 + syncTimerMock.CurrentTimeCalled = func() time.Time { + return time.Unix(0, timeElapsed) + } + container.SetRoundHandler(roundHandlerMock) + container.SetSyncTimer(syncTimerMock) + + assert.True(t, haveTimeInCurrentSubound()) +} + +func TestSubroundBlock_HaveTimeInCurrentSuboundShouldReturnFalse(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + haveTimeInCurrentSubound := func() bool { + roundStartTime := sr.RoundHandler().TimeStamp() + currentTime := sr.SyncTimer().CurrentTime() + elapsedTime := currentTime.Sub(roundStartTime) + remainingTime := sr.EndTime() - int64(elapsedTime) + + return time.Duration(remainingTime) > 0 + } + roundHandlerMock := &consensusMocks.RoundHandlerMock{} + roundHandlerMock.TimeDurationCalled = func() time.Duration { + return 4000 * time.Millisecond + } + roundHandlerMock.TimeStampCalled = func() time.Time { + return time.Unix(0, 0) + } + syncTimerMock := &consensusMocks.SyncTimerMock{} + timeElapsed := sr.EndTime() + 1 + syncTimerMock.CurrentTimeCalled = func() time.Time { + return time.Unix(0, timeElapsed) + } + container.SetRoundHandler(roundHandlerMock) + container.SetSyncTimer(syncTimerMock) + + assert.False(t, haveTimeInCurrentSubound()) +} + +func TestSubroundBlock_CreateHeaderNilCurrentHeader(t *testing.T) { + blockChain := &testscommon.ChainHandlerStub{ + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return nil + }, + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{ + Nonce: uint64(0), + Signature: []byte("genesis signature"), + RandSeed: []byte{0}, + } + }, + GetGenesisHeaderHashCalled: func() []byte { + return []byte("genesis header hash") + }, + } + + consensusContainers := createConsensusContainers() + for _, container := range consensusContainers { + sr := initSubroundBlock(blockChain, container, &statusHandler.AppStatusHandlerStub{}) + _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(nil, nil) + header, _ := sr.CreateHeader() + header, body, _ := sr.CreateBlock(header) + marshalizedBody, _ := sr.Marshalizer().Marshal(body) + marshalizedHeader, _ := sr.Marshalizer().Marshal(header) + _ = sr.SendBlockBody(body, marshalizedBody) + _ = sr.SendBlockHeader(header, marshalizedHeader) + + expectedHeader, _ := container.BlockProcessor().CreateNewHeader(uint64(sr.RoundHandler().Index()), uint64(1)) + err := expectedHeader.SetTimeStamp(uint64(sr.RoundHandler().TimeStamp().Unix())) + require.Nil(t, err) + err = expectedHeader.SetRootHash([]byte{}) + require.Nil(t, err) + err = 
expectedHeader.SetPrevHash(sr.BlockChain().GetGenesisHeaderHash()) + require.Nil(t, err) + err = expectedHeader.SetPrevRandSeed(sr.BlockChain().GetGenesisHeader().GetRandSeed()) + require.Nil(t, err) + err = expectedHeader.SetRandSeed(make([]byte, 0)) + require.Nil(t, err) + err = expectedHeader.SetMiniBlockHeaderHandlers(header.GetMiniBlockHeaderHandlers()) + require.Nil(t, err) + err = expectedHeader.SetChainID(chainID) + require.Nil(t, err) + require.Equal(t, expectedHeader, header) + } +} + +func TestSubroundBlock_CreateHeaderNotNilCurrentHeader(t *testing.T) { + consensusContainers := createConsensusContainers() + for _, container := range consensusContainers { + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(&block.Header{ + Nonce: 1, + }, []byte("root hash")) + + header, _ := sr.CreateHeader() + header, body, _ := sr.CreateBlock(header) + marshalizedBody, _ := sr.Marshalizer().Marshal(body) + marshalizedHeader, _ := sr.Marshalizer().Marshal(header) + _ = sr.SendBlockBody(body, marshalizedBody) + _ = sr.SendBlockHeader(header, marshalizedHeader) + + expectedHeader, _ := container.BlockProcessor().CreateNewHeader( + uint64(sr.RoundHandler().Index()), + sr.BlockChain().GetCurrentBlockHeader().GetNonce()+1) + err := expectedHeader.SetTimeStamp(uint64(sr.RoundHandler().TimeStamp().Unix())) + require.Nil(t, err) + err = expectedHeader.SetRootHash([]byte{}) + require.Nil(t, err) + err = expectedHeader.SetPrevHash(sr.BlockChain().GetCurrentBlockHeaderHash()) + require.Nil(t, err) + err = expectedHeader.SetRandSeed(make([]byte, 0)) + require.Nil(t, err) + err = expectedHeader.SetMiniBlockHeaderHandlers(header.GetMiniBlockHeaderHandlers()) + require.Nil(t, err) + err = expectedHeader.SetChainID(chainID) + require.Nil(t, err) + require.Equal(t, expectedHeader, header) + } +} + +func TestSubroundBlock_CreateHeaderMultipleMiniBlocks(t *testing.T) { + mbHeaders := []block.MiniBlockHeader{ + {Hash: []byte("mb1"), SenderShardID: 1, ReceiverShardID: 1}, + {Hash: []byte("mb2"), SenderShardID: 1, ReceiverShardID: 2}, + {Hash: []byte("mb3"), SenderShardID: 2, ReceiverShardID: 3}, + } + blockChainMock := testscommon.ChainHandlerStub{ + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &block.Header{ + Nonce: 1, + } + }, + } + container := consensusMocks.InitConsensusCore() + bp := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + bp.CreateBlockCalled = func(header data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { + shardHeader, _ := header.(*block.Header) + shardHeader.MiniBlockHeaders = mbHeaders + shardHeader.RootHash = []byte{} + + return shardHeader, &block.Body{}, nil + } + sr := initSubroundBlockWithBlockProcessor(bp, container) + container.SetBlockchain(&blockChainMock) + + header, _ := sr.CreateHeader() + header, body, _ := sr.CreateBlock(header) + marshalizedBody, _ := sr.Marshalizer().Marshal(body) + marshalizedHeader, _ := sr.Marshalizer().Marshal(header) + _ = sr.SendBlockBody(body, marshalizedBody) + _ = sr.SendBlockHeader(header, marshalizedHeader) + + expectedHeader := &block.Header{ + Round: uint64(sr.RoundHandler().Index()), + TimeStamp: uint64(sr.RoundHandler().TimeStamp().Unix()), + RootHash: []byte{}, + Nonce: sr.BlockChain().GetCurrentBlockHeader().GetNonce() + 1, + PrevHash: sr.BlockChain().GetCurrentBlockHeaderHash(), + RandSeed: make([]byte, 0), + MiniBlockHeaders: mbHeaders, + ChainID: chainID, + } + + assert.Equal(t, 
expectedHeader, header) +} + +func TestSubroundBlock_CreateHeaderNilMiniBlocks(t *testing.T) { + expectedErr := errors.New("nil mini blocks") + container := consensusMocks.InitConsensusCore() + bp := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + bp.CreateBlockCalled = func(header data.HeaderHandler, haveTime func() bool) (data.HeaderHandler, data.BodyHandler, error) { + return nil, nil, expectedErr + } + sr := initSubroundBlockWithBlockProcessor(bp, container) + _ = sr.BlockChain().SetCurrentBlockHeaderAndRootHash(&block.Header{ + Nonce: 1, + }, []byte("root hash")) + header, _ := sr.CreateHeader() + _, _, err := sr.CreateBlock(header) + assert.Equal(t, expectedErr, err) +} + +func TestSubroundBlock_CallFuncRemainingTimeWithStructShouldWork(t *testing.T) { + roundStartTime := time.Now() + maxTime := 100 * time.Millisecond + newRoundStartTime := roundStartTime + remainingTimeInCurrentRound := func() time.Duration { + return RemainingTimeWithStruct(newRoundStartTime, maxTime) + } + assert.True(t, remainingTimeInCurrentRound() > 0) + + time.Sleep(200 * time.Millisecond) + assert.True(t, remainingTimeInCurrentRound() < 0) +} + +func TestSubroundBlock_CallFuncRemainingTimeWithStructShouldNotWork(t *testing.T) { + roundStartTime := time.Now() + maxTime := 100 * time.Millisecond + remainingTimeInCurrentRound := func() time.Duration { + return RemainingTimeWithStruct(roundStartTime, maxTime) + } + assert.True(t, remainingTimeInCurrentRound() > 0) + + time.Sleep(200 * time.Millisecond) + assert.True(t, remainingTimeInCurrentRound() < 0) + + roundStartTime = roundStartTime.Add(500 * time.Millisecond) + assert.False(t, remainingTimeInCurrentRound() < 0) +} + +func RemainingTimeWithStruct(startTime time.Time, maxTime time.Duration) time.Duration { + currentTime := time.Now() + elapsedTime := currentTime.Sub(startTime) + remainingTime := maxTime - elapsedTime + return remainingTime +} + +func TestSubroundBlock_ReceivedBlockComputeProcessDuration(t *testing.T) { + t.Parallel() + + srStartTime := int64(5 * roundTimeDuration / 100) + srEndTime := int64(25 * roundTimeDuration / 100) + srDuration := srEndTime - srStartTime + delay := srDuration * 430 / 1000 + + container := consensusMocks.InitConsensusCore() + receivedValue := uint64(0) + container.SetBlockProcessor(&testscommon.BlockProcessorStub{ + ProcessBlockCalled: func(_ data.HeaderHandler, _ data.BodyHandler, _ func() time.Duration) error { + time.Sleep(time.Duration(delay)) + return nil + }, + }) + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{ + SetUInt64ValueHandler: func(key string, value uint64) { + receivedValue = value + }}) + hdr := &block.Header{} + blkBody := &block.Body{} + blkBodyStr, _ := mock.MarshalizerMock{}.Marshal(blkBody) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + cnsMsg := consensus.NewConsensusMessage( + nil, + nil, + blkBodyStr, + nil, + []byte(leader), + []byte("sig"), + int(bls.MtBlockBody), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + sr.SetHeader(hdr) + sr.SetBody(blkBody) + + minimumExpectedValue := uint64(delay * 100 / srDuration) + _ = sr.ProcessReceivedBlock(cnsMsg) + + assert.True(t, + receivedValue >= minimumExpectedValue, + fmt.Sprintf("minimum expected was %d, got %d", minimumExpectedValue, receivedValue), + ) +} + +func TestSubroundBlock_ReceivedBlockComputeProcessDurationWithZeroDurationShouldNotPanic(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r != nil { + assert.Fail(t, "should not have paniced", r) + } + 
}() + + container := consensusMocks.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubroundForSRBlock(consensusState, ch, container, &statusHandler.AppStatusHandlerStub{}) + srBlock := defaultSubroundBlockWithoutErrorFromSubround(sr) + + srBlock.ComputeSubroundProcessingMetric(time.Now(), "dummy") +} + +func TestSubroundBlock_ReceivedBlockHeader(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + prevHash := []byte("header hash") + prevHeader := createDefaultHeader() + blockchain := &testscommon.ChainHandlerStub{ + GetCurrentBlockHeaderHashCalled: func() []byte { + return prevHash + }, + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &block.HeaderV2{ + Header: prevHeader, + } + }, + } + container.SetBlockchain(blockchain) + + // nil header + sr.ReceivedBlockHeader(nil) + + // header not for current consensus + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{}) + + // nil fields on header + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{ + CheckFieldsForNilCalled: func() error { + return expectedErr + }, + }) + + // header not for current consensus + sr.ReceivedBlockHeader(&testscommon.HeaderHandlerStub{}) + + headerForCurrentConsensus := &testscommon.HeaderHandlerStub{ + GetShardIDCalled: func() uint32 { + return container.ShardCoordinator().SelfId() + }, + RoundField: uint64(container.RoundHandler().Index()), + GetPrevHashCalled: func() []byte { + return prevHash + }, + GetNonceCalled: func() uint64 { + return prevHeader.GetNonce() + 1 + }, + GetPrevRandSeedCalled: func() []byte { + return prevHeader.RandSeed + }, + } + + // leader + defaultLeader := sr.Leader() + sr.SetLeader(sr.SelfPubKey()) + sr.ReceivedBlockHeader(headerForCurrentConsensus) + sr.SetLeader(defaultLeader) + + // consensus data already set + sr.SetData([]byte("some data")) + sr.ReceivedBlockHeader(headerForCurrentConsensus) + sr.SetData(nil) + + // header leader is not the current one + sr.SetLeader("X") + sr.ReceivedBlockHeader(headerForCurrentConsensus) + sr.SetLeader(defaultLeader) + + // header already received + sr.SetHeader(&testscommon.HeaderHandlerStub{}) + sr.ReceivedBlockHeader(headerForCurrentConsensus) + sr.SetHeader(nil) + + // self job already done + _ = sr.SetJobDone(sr.SelfPubKey(), sr.Current(), true) + sr.ReceivedBlockHeader(headerForCurrentConsensus) + _ = sr.SetJobDone(sr.SelfPubKey(), sr.Current(), false) + + // subround already finished + sr.SetStatus(sr.Current(), spos.SsFinished) + sr.ReceivedBlockHeader(headerForCurrentConsensus) + sr.SetStatus(sr.Current(), spos.SsNotFinished) + + // marshal error + container.SetMarshalizer(&testscommon.MarshallerStub{ + MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, expectedErr + }, + }) + sr.ReceivedBlockHeader(headerForCurrentConsensus) + container.SetMarshalizer(&testscommon.MarshallerStub{}) + + // should work + sr.ReceivedBlockHeader(headerForCurrentConsensus) +} + +func TestSubroundBlock_GetLeaderForHeader(t *testing.T) { + t.Parallel() + + t.Run("should fail if not able to compute consensus group", func(t *testing.T) { + t.Parallel() + + expErr := errors.New("expected error") + + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + container.SetNodesCoordinator(&shardingMocks.NodesCoordinatorStub{ + ComputeConsensusGroupCalled: 
func(randomness []byte, round uint64, shardId, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return nil, nil, expErr + }, + }) + + leader, err := sr.GetLeaderForHeader(&block.Header{ + Epoch: 10, + }) + + require.Nil(t, leader) + require.Equal(t, expErr, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + + expLeader := shardingMocks.NewValidatorMock([]byte("pubKey"), 1, 1) + + container.SetNodesCoordinator(&shardingMocks.NodesCoordinatorStub{ + ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return expLeader, make([]nodesCoordinator.Validator, 0), nil + }, + }) + + leader, err := sr.GetLeaderForHeader(&block.Header{ + Epoch: 10, + }) + + require.Nil(t, err) + require.Equal(t, expLeader.PubKey(), leader) + }) +} + +func TestSubroundBlock_IsInterfaceNil(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundBlock(nil, container, nil) + require.True(t, sr.IsInterfaceNil()) + + sr = initSubroundBlock(nil, container, &statusHandler.AppStatusHandlerStub{}) + require.False(t, sr.IsInterfaceNil()) +} diff --git a/consensus/spos/bls/v2/subroundEndRound.go b/consensus/spos/bls/v2/subroundEndRound.go new file mode 100644 index 00000000000..66c585d1b9b --- /dev/null +++ b/consensus/spos/bls/v2/subroundEndRound.go @@ -0,0 +1,973 @@ +package v2 + +import ( + "bytes" + "context" + "encoding/hex" + "errors" + "fmt" + "math/rand" + "sync" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-core-go/display" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/process/headerCheck" +) + +const timeBetweenSignaturesChecks = time.Millisecond * 5 + +type subroundEndRound struct { + *spos.Subround + processingThresholdPercentage int + appStatusHandler core.AppStatusHandler + mutProcessingEndRound sync.Mutex + sentSignatureTracker spos.SentSignaturesTracker + worker spos.WorkerHandler + signatureThrottler core.Throttler +} + +// NewSubroundEndRound creates a subroundEndRound object +func NewSubroundEndRound( + baseSubround *spos.Subround, + processingThresholdPercentage int, + appStatusHandler core.AppStatusHandler, + sentSignatureTracker spos.SentSignaturesTracker, + worker spos.WorkerHandler, + signatureThrottler core.Throttler, +) (*subroundEndRound, error) { + err := checkNewSubroundEndRoundParams(baseSubround) + if err != nil { + return nil, err + } + if check.IfNil(appStatusHandler) { + return nil, spos.ErrNilAppStatusHandler + } + if check.IfNil(sentSignatureTracker) { + return nil, ErrNilSentSignatureTracker + } + if check.IfNil(worker) { + return nil, spos.ErrNilWorker + } + if check.IfNil(signatureThrottler) { + return nil, spos.ErrNilThrottler + } + + srEndRound := 
subroundEndRound{ + Subround: baseSubround, + processingThresholdPercentage: processingThresholdPercentage, + appStatusHandler: appStatusHandler, + mutProcessingEndRound: sync.Mutex{}, + sentSignatureTracker: sentSignatureTracker, + worker: worker, + signatureThrottler: signatureThrottler, + } + srEndRound.Job = srEndRound.doEndRoundJob + srEndRound.Check = srEndRound.doEndRoundConsensusCheck + srEndRound.Extend = worker.Extend + + return &srEndRound, nil +} + +func checkNewSubroundEndRoundParams( + baseSubround *spos.Subround, +) error { + if baseSubround == nil { + return spos.ErrNilSubround + } + if check.IfNil(baseSubround.ConsensusStateHandler) { + return spos.ErrNilConsensusState + } + + err := spos.ValidateConsensusCore(baseSubround.ConsensusCoreHandler) + + return err +} + +func (sr *subroundEndRound) isProofForCurrentConsensus(proof consensus.ProofHandler) bool { + return bytes.Equal(sr.GetData(), proof.GetHeaderHash()) +} + +// receivedProof method is called when a block header final info is received +func (sr *subroundEndRound) receivedProof(proof consensus.ProofHandler) { + sr.mutProcessingEndRound.Lock() + defer sr.mutProcessingEndRound.Unlock() + + if sr.IsSelfJobDone(sr.Current()) { + return + } + if !sr.IsConsensusDataSet() { + return + } + if check.IfNil(sr.GetHeader()) { + return + } + if !sr.isProofForCurrentConsensus(proof) { + return + } + + // no need to re-verify the proof since it was already verified when it was added to the proofs pool + log.Debug("step 3: block header final info has been received", + "PubKeysBitmap", proof.GetPubKeysBitmap(), + "AggregateSignature", proof.GetAggregatedSignature(), + "HederHash", proof.GetHeaderHash()) + + sr.doEndRoundJobByNode() +} + +// receivedInvalidSignersInfo method is called when a message with invalid signers has been received +func (sr *subroundEndRound) receivedInvalidSignersInfo(_ context.Context, cnsDta *consensus.Message) bool { + messageSender := string(cnsDta.PubKey) + + if !sr.IsConsensusDataSet() { + return false + } + if check.IfNil(sr.GetHeader()) { + return false + } + + isSelfSender := sr.IsNodeSelf(messageSender) || sr.IsKeyManagedBySelf([]byte(messageSender)) + if isSelfSender { + return false + } + + if !sr.IsConsensusDataEqual(cnsDta.BlockHeaderHash) { + return false + } + + if !sr.CanProcessReceivedMessage(cnsDta, sr.RoundHandler().Index(), sr.Current()) { + return false + } + + if len(cnsDta.InvalidSigners) == 0 { + return false + } + + invalidSignersCache := sr.InvalidSignersCache() + if invalidSignersCache.CheckKnownInvalidSigners(cnsDta.BlockHeaderHash, cnsDta.InvalidSigners) { + return false + } + + invalidSignersPubKeys, err := sr.verifyInvalidSigners(cnsDta.InvalidSigners) + if err != nil { + log.Trace("receivedInvalidSignersInfo.verifyInvalidSigners", "error", err.Error()) + return false + } + + log.Debug("step 3: invalid signers info has been evaluated") + + invalidSignersCache.AddInvalidSigners(cnsDta.BlockHeaderHash, cnsDta.InvalidSigners, invalidSignersPubKeys) + + sr.PeerHonestyHandler().ChangeScore( + messageSender, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.LeaderPeerHonestyIncreaseFactor, + ) + + return true +} + +func (sr *subroundEndRound) verifyInvalidSigners(invalidSigners []byte) ([]string, error) { + messages, err := sr.MessageSigningHandler().Deserialize(invalidSigners) + if err != nil { + return nil, err + } + + pubKeys := make([]string, 0, len(messages)) + for _, msg := range messages { + pubKey, errVerify := sr.verifyInvalidSigner(msg) + if errVerify != nil { + 
return nil, errVerify + } + + if len(pubKey) > 0 { + pubKeys = append(pubKeys, pubKey) + } + } + + return pubKeys, nil +} + +func (sr *subroundEndRound) verifyInvalidSigner(msg p2p.MessageP2P) (string, error) { + err := sr.MessageSigningHandler().Verify(msg) + if err != nil { + return "", err + } + + cnsMsg := &consensus.Message{} + err = sr.Marshalizer().Unmarshal(cnsMsg, msg.Data()) + if err != nil { + return "", err + } + + err = sr.SigningHandler().VerifySingleSignature(cnsMsg.PubKey, cnsMsg.BlockHeaderHash, cnsMsg.SignatureShare) + if err != nil { + log.Debug("verifyInvalidSigner: confirmed that node provided invalid signature", + "pubKey", cnsMsg.PubKey, + "blockHeaderHash", cnsMsg.BlockHeaderHash, + "error", err.Error(), + ) + sr.applyBlacklistOnNode(msg.Peer()) + + return string(cnsMsg.PubKey), nil + } + + return "", nil +} + +func (sr *subroundEndRound) applyBlacklistOnNode(peer core.PeerID) { + sr.PeerBlacklistHandler().BlacklistPeer(peer, common.InvalidSigningBlacklistDuration) +} + +// doEndRoundJob method does the job of the subround EndRound +func (sr *subroundEndRound) doEndRoundJob(_ context.Context) bool { + if check.IfNil(sr.GetHeader()) { + return false + } + + sr.mutProcessingEndRound.Lock() + defer sr.mutProcessingEndRound.Unlock() + + return sr.doEndRoundJobByNode() +} + +func (sr *subroundEndRound) commitBlock() error { + startTime := time.Now() + err := sr.BlockProcessor().CommitBlock(sr.GetHeader(), sr.GetBody()) + elapsedTime := time.Since(startTime) + if elapsedTime >= common.CommitMaxTime { + log.Warn("doEndRoundJobByNode.CommitBlock", "elapsed time", elapsedTime) + } else { + log.Debug("elapsed time to commit block", "time [s]", elapsedTime) + } + if err != nil { + log.Debug("doEndRoundJobByNode.CommitBlock", "error", err) + return err + } + + return nil +} + +func (sr *subroundEndRound) doEndRoundJobByNode() bool { + if sr.shouldSendProof() { + if !sr.waitForSignalSync() { + return false + } + + proofSent, err := sr.sendProof() + shouldWaitForMoreSignatures := errors.Is(err, spos.ErrInvalidNumSigShares) + // if not enough valid signatures were detected, wait a bit more + // either more signatures will be received, either proof from another participant + if shouldWaitForMoreSignatures { + return sr.doEndRoundJobByNode() + } + + if proofSent { + err := sr.prepareBroadcastBlockData() + log.LogIfError(err) + } + } + + return sr.finalizeConfirmedBlock() +} + +func (sr *subroundEndRound) prepareBroadcastBlockData() error { + miniBlocks, transactions, err := sr.BlockProcessor().MarshalizedDataToBroadcast(sr.GetHeader(), sr.GetBody()) + if err != nil { + return err + } + + getEquivalentProofSender := sr.getEquivalentProofSender() + go sr.BroadcastMessenger().PrepareBroadcastBlockDataWithEquivalentProofs(sr.GetHeader(), miniBlocks, transactions, []byte(getEquivalentProofSender)) + + return nil +} + +func (sr *subroundEndRound) waitForProof() bool { + shardID := sr.ShardCoordinator().SelfId() + headerHash := sr.GetData() + if sr.EquivalentProofsPool().HasProof(shardID, headerHash) { + return true + } + + ctx, cancel := context.WithTimeout(context.Background(), sr.RoundHandler().TimeDuration()) + defer cancel() + + for { + select { + case <-time.After(time.Millisecond): + if sr.EquivalentProofsPool().HasProof(shardID, headerHash) { + return true + } + case <-ctx.Done(): + return false + } + } +} + +func (sr *subroundEndRound) finalizeConfirmedBlock() bool { + if !sr.waitForProof() { + return false + } + + ok := sr.ScheduledProcessor().IsProcessedOKWithTimeout() + // 
placeholder for subroundEndRound.doEndRoundJobByLeader script + if !ok { + return false + } + + err := sr.commitBlock() + if err != nil { + return false + } + + sr.SetStatus(sr.Current(), spos.SsFinished) + + sr.worker.DisplayStatistics() + + log.Debug("step 3: Body and Header have been committed") + + msg := fmt.Sprintf("Added proposed block with nonce %d in blockchain", sr.GetHeader().GetNonce()) + log.Debug(display.Headline(msg, sr.SyncTimer().FormattedCurrentTime(), "+")) + + sr.updateMetricsForLeader() + + return true +} + +func (sr *subroundEndRound) sendProof() (bool, error) { + if !sr.shouldSendProof() { + return false, nil + } + + bitmap := sr.GenerateBitmap(bls.SrSignature) + err := sr.checkSignaturesValidity(bitmap) + if err != nil { + log.Debug("sendProof.checkSignaturesValidity", "error", err.Error()) + return false, err + } + + currentSender := sr.getEquivalentProofSender() + + // Aggregate signatures, handle invalid signers and send final info if needed + bitmap, sig, err := sr.aggregateSigsAndHandleInvalidSigners(bitmap, currentSender) + if err != nil { + log.Debug("sendProof.aggregateSigsAndHandleInvalidSigners", "error", err.Error()) + return false, err + } + + roundHandler := sr.RoundHandler() + if roundHandler.RemainingTime(roundHandler.TimeStamp(), roundHandler.TimeDuration()) < 0 { + log.Debug("sendProof: time is out -> cancel broadcasting final info and header", + "round time stamp", roundHandler.TimeStamp(), + "current time", time.Now()) + return false, ErrTimeOut + } + + // broadcast header proof + err = sr.createAndBroadcastProof(sig, bitmap, currentSender) + if err != nil && !errors.Is(err, ErrProofAlreadyPropagated) { + log.Warn("sendProof.createAndBroadcastProof", "error", err.Error()) + } + + proofSent := err == nil + return proofSent, err +} + +func (sr *subroundEndRound) shouldSendProof() bool { + if sr.EquivalentProofsPool().HasProof(sr.ShardCoordinator().SelfId(), sr.GetData()) { + log.Debug("shouldSendProof: equivalent message already processed") + return false + } + + return sr.IsSelfInConsensusGroup() +} + +func (sr *subroundEndRound) aggregateSigsAndHandleInvalidSigners(bitmap []byte, sender string) ([]byte, []byte, error) { + if sr.EquivalentProofsPool().HasProof(sr.ShardCoordinator().SelfId(), sr.GetData()) { + return nil, nil, ErrProofAlreadyPropagated + } + sig, err := sr.SigningHandler().AggregateSigs(bitmap, sr.GetHeader().GetEpoch()) + if err != nil { + log.Debug("doEndRoundJobByNode.AggregateSigs", "error", err.Error()) + + return sr.handleInvalidSignersOnAggSigFail(sender) + } + + err = sr.SigningHandler().SetAggregatedSig(sig) + if err != nil { + log.Debug("doEndRoundJobByNode.SetAggregatedSig", "error", err.Error()) + return nil, nil, err + } + + // the header (hash) verified here is with leader signature on it + err = sr.SigningHandler().Verify(sr.GetData(), bitmap, sr.GetHeader().GetEpoch()) + if err != nil { + log.Debug("doEndRoundJobByNode.Verify", "error", err.Error()) + + return sr.handleInvalidSignersOnAggSigFail(sender) + } + + return bitmap, sig, nil +} + +func (sr *subroundEndRound) checkGoRoutinesThrottler(ctx context.Context) error { + for { + if sr.signatureThrottler.CanProcess() { + break + } + + select { + case <-time.After(time.Millisecond): + continue + case <-ctx.Done(): + return spos.ErrTimeIsOut + } + } + return nil +} + +// verifySignature implements parallel signature verification +func (sr *subroundEndRound) verifySignature(i int, pk string, sigShare []byte) error { + err := 
sr.SigningHandler().VerifySignatureShare(uint16(i), sigShare, sr.GetData(), sr.GetHeader().GetEpoch()) + if err != nil { + log.Trace("VerifySignatureShare returned an error: ", "error", err) + errSetJob := sr.SetJobDone(pk, bls.SrSignature, false) + if errSetJob != nil { + return errSetJob + } + + decreaseFactor := -spos.ValidatorPeerHonestyIncreaseFactor + spos.ValidatorPeerHonestyDecreaseFactor + + sr.PeerHonestyHandler().ChangeScore( + pk, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + decreaseFactor, + ) + return err + } + + log.Trace("verifyNodesOnAggSigVerificationFail: verifying signature share", "public key", pk) + + return nil +} + +func (sr *subroundEndRound) verifyNodesOnAggSigFail(ctx context.Context) ([]string, error) { + wg := &sync.WaitGroup{} + mutex := &sync.Mutex{} + invalidPubKeys := make([]string, 0) + pubKeys := sr.ConsensusGroup() + + if check.IfNil(sr.GetHeader()) { + return nil, spos.ErrNilHeader + } + + for i, pk := range pubKeys { + isJobDone, err := sr.JobDone(pk, bls.SrSignature) + if err != nil || !isJobDone { + continue + } + + sigShare, err := sr.SigningHandler().SignatureShare(uint16(i)) + if err != nil { + return nil, err + } + + err = sr.checkGoRoutinesThrottler(ctx) + if err != nil { + return nil, err + } + + sr.signatureThrottler.StartProcessing() + + wg.Add(1) + + go func(i int, pk string, sigShare []byte) { + defer func() { + sr.signatureThrottler.EndProcessing() + wg.Done() + }() + errSigVerification := sr.verifySignature(i, pk, sigShare) + if errSigVerification != nil { + mutex.Lock() + invalidPubKeys = append(invalidPubKeys, pk) + mutex.Unlock() + } + }(i, pk, sigShare) + } + wg.Wait() + + return invalidPubKeys, nil +} + +func (sr *subroundEndRound) getFullMessagesForInvalidSigners(invalidPubKeys []string) ([]byte, error) { + p2pMessages := make([]p2p.MessageP2P, 0) + + for _, pk := range invalidPubKeys { + p2pMsg, ok := sr.GetMessageWithSignature(pk) + if !ok { + log.Trace("message not found in state for invalid signer", "pubkey", pk) + continue + } + + p2pMessages = append(p2pMessages, p2pMsg) + } + + invalidSigners, err := sr.MessageSigningHandler().Serialize(p2pMessages) + if err != nil { + return nil, err + } + + return invalidSigners, nil +} + +func (sr *subroundEndRound) handleInvalidSignersOnAggSigFail(sender string) ([]byte, []byte, error) { + ctx, cancel := context.WithTimeout(context.Background(), sr.RoundHandler().TimeDuration()) + invalidPubKeys, err := sr.verifyNodesOnAggSigFail(ctx) + cancel() + if err != nil { + log.Debug("handleInvalidSignersOnAggSigFail.verifyNodesOnAggSigFail", "error", err.Error()) + return nil, nil, err + } + + invalidSigners, err := sr.getFullMessagesForInvalidSigners(invalidPubKeys) + if err != nil { + log.Debug("handleInvalidSignersOnAggSigFail.getFullMessagesForInvalidSigners", "error", err.Error()) + return nil, nil, err + } + + if sr.EquivalentProofsPool().HasProof(sr.ShardCoordinator().SelfId(), sr.GetData()) { + return nil, nil, ErrProofAlreadyPropagated + } + + if len(invalidSigners) > 0 { + sr.createAndBroadcastInvalidSigners(invalidSigners, invalidPubKeys, sender) + } + + bitmap, sig, err := sr.computeAggSigOnValidNodes() + if err != nil { + log.Debug("handleInvalidSignersOnAggSigFail.computeAggSigOnValidNodes", "error", err.Error()) + return nil, nil, err + } + + return bitmap, sig, nil +} + +func (sr *subroundEndRound) computeAggSigOnValidNodes() ([]byte, []byte, error) { + threshold := sr.Threshold(bls.SrSignature) + numValidSigShares := sr.ComputeSize(bls.SrSignature) + + if 
check.IfNil(sr.GetHeader()) { + return nil, nil, spos.ErrNilHeader + } + + if numValidSigShares < threshold { + return nil, nil, fmt.Errorf("%w: number of valid sig shares lower than threshold, numSigShares: %d, threshold: %d", + spos.ErrInvalidNumSigShares, numValidSigShares, threshold) + } + + bitmap := sr.GenerateBitmap(bls.SrSignature) + err := sr.checkSignaturesValidity(bitmap) + if err != nil { + return nil, nil, err + } + + sig, err := sr.SigningHandler().AggregateSigs(bitmap, sr.GetHeader().GetEpoch()) + if err != nil { + return nil, nil, err + } + + err = sr.SigningHandler().SetAggregatedSig(sig) + if err != nil { + return nil, nil, err + } + + log.Trace("computeAggSigOnValidNodes", + "bitmap", bitmap, + "threshold", threshold, + "numValidSigShares", numValidSigShares, + ) + + return bitmap, sig, nil +} + +func (sr *subroundEndRound) createAndBroadcastProof( + signature []byte, + bitmap []byte, + sender string, +) error { + if sr.EquivalentProofsPool().HasProof(sr.ShardCoordinator().SelfId(), sr.GetData()) { + // no need to broadcast a proof if already received and verified one + return ErrProofAlreadyPropagated + } + + headerProof := &block.HeaderProof{ + PubKeysBitmap: bitmap, + AggregatedSignature: signature, + HeaderHash: sr.GetData(), + HeaderEpoch: sr.GetHeader().GetEpoch(), + HeaderNonce: sr.GetHeader().GetNonce(), + HeaderShardId: sr.GetHeader().GetShardID(), + HeaderRound: sr.GetHeader().GetRound(), + IsStartOfEpoch: sr.GetHeader().IsStartOfEpochBlock(), + } + + err := sr.BroadcastMessenger().BroadcastEquivalentProof(headerProof, []byte(sender)) + if err != nil { + return err + } + + log.Debug("step 3: block header proof has been sent", + "PubKeysBitmap", bitmap, + "AggregateSignature", signature, + "proof sender", hex.EncodeToString([]byte(sender))) + + return nil +} + +func (sr *subroundEndRound) getEquivalentProofSender() string { + if sr.IsNodeInConsensusGroup(sr.SelfPubKey()) { + return sr.SelfPubKey() // single key mode + } + + return sr.getRandomManagedKeyProofSender() +} + +func (sr *subroundEndRound) getRandomManagedKeyProofSender() string { + // in multikey mode, we randomly select one managed key for the proof + consensusKeysManagedByCurrentNode := make([]string, 0) + for _, validator := range sr.ConsensusGroup() { + if !sr.IsKeyManagedBySelf([]byte(validator)) { + continue + } + + consensusKeysManagedByCurrentNode = append(consensusKeysManagedByCurrentNode, validator) + } + + if len(consensusKeysManagedByCurrentNode) == 0 { + return sr.SelfPubKey() // fallback return self pub key, should never happen + } + + randIdx := rand.Intn(len(consensusKeysManagedByCurrentNode)) + randManagedKey := consensusKeysManagedByCurrentNode[randIdx] + + return randManagedKey +} + +func (sr *subroundEndRound) createAndBroadcastInvalidSigners( + invalidSigners []byte, + invalidSignersPubKeys []string, + sender string, +) { + if !sr.ShouldConsiderSelfKeyInConsensus() && !sr.IsMultiKeyInConsensusGroup() { + return + } + + cnsMsg := consensus.NewConsensusMessage( + sr.GetData(), + nil, + nil, + nil, + []byte(sender), + nil, + int(bls.MtInvalidSigners), + sr.RoundHandler().Index(), + sr.ChainID(), + nil, + nil, + nil, + sr.GetAssociatedPid([]byte(sender)), + invalidSigners, + ) + + sr.InvalidSignersCache().AddInvalidSigners(sr.GetData(), invalidSigners, invalidSignersPubKeys) + + err := sr.BroadcastMessenger().BroadcastConsensusMessage(cnsMsg) + if err != nil { + log.Debug("doEndRoundJob.BroadcastConsensusMessage", "error", err.Error()) + return + } + + log.Debug("step 3: invalid 
signers info has been sent", "sender", hex.EncodeToString([]byte(sender))) +} + +func (sr *subroundEndRound) updateMetricsForLeader() { + if !sr.IsSelfLeader() { + return + } + + sr.appStatusHandler.Increment(common.MetricCountAcceptedBlocks) + sr.appStatusHandler.SetStringValue(common.MetricConsensusRoundState, + fmt.Sprintf("valid block produced in %f sec", time.Since(sr.RoundHandler().TimeStamp()).Seconds())) +} + +// doEndRoundConsensusCheck method checks if the consensus is achieved +func (sr *subroundEndRound) doEndRoundConsensusCheck() bool { + if sr.GetRoundCanceled() { + return false + } + + return sr.IsSubroundFinished(sr.Current()) +} + +func (sr *subroundEndRound) checkSignaturesValidity(bitmap []byte) error { + consensusGroup := sr.ConsensusGroup() + + shouldApplyFallbackValidation := sr.FallbackHeaderValidator().ShouldApplyFallbackValidation(sr.GetHeader()) + err := common.IsConsensusBitmapValid(log, consensusGroup, bitmap, shouldApplyFallbackValidation) + if err != nil { + return err + } + + signers := headerCheck.ComputeSignersPublicKeys(consensusGroup, bitmap) + for _, pubKey := range signers { + isSigJobDone, err := sr.JobDone(pubKey, bls.SrSignature) + if err != nil { + return err + } + + if !isSigJobDone { + return spos.ErrNilSignature + } + } + + return nil +} + +func (sr *subroundEndRound) isOutOfTime() bool { + startTime := sr.GetRoundTimeStamp() + maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 + if sr.RoundHandler().RemainingTime(startTime, maxTime) < 0 { + log.Debug("canceled round, time is out", + "round", sr.SyncTimer().FormattedCurrentTime(), sr.RoundHandler().Index(), + "subround", sr.Name()) + + sr.SetRoundCanceled(true) + return true + } + + return false +} + +func (sr *subroundEndRound) getMinConsensusGroupIndexOfManagedKeys() int { + minIdx := sr.ConsensusGroupSize() + + for idx, validator := range sr.ConsensusGroup() { + if !sr.IsKeyManagedBySelf([]byte(validator)) { + continue + } + + if idx < minIdx { + minIdx = idx + } + } + + return minIdx +} + +func (sr *subroundEndRound) waitForSignalSync() bool { + if sr.IsSubroundFinished(sr.Current()) { + return true + } + + if sr.checkReceivedSignatures() { + return true + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go sr.waitSignatures(ctx) + timerBetweenStatusChecks := time.NewTimer(timeBetweenSignaturesChecks) + + remainingSRTime := sr.remainingTime() + timeout := time.NewTimer(remainingSRTime) + for { + select { + case <-timerBetweenStatusChecks.C: + if sr.IsSubroundFinished(sr.Current()) { + log.Trace("subround already finished", "subround", sr.Name()) + return true + } + + if sr.checkReceivedSignatures() { + return true + } + timerBetweenStatusChecks.Reset(timeBetweenSignaturesChecks) + case <-timeout.C: + log.Debug("timeout while waiting for signatures or final info", "subround", sr.Name()) + return false + } + } +} + +func (sr *subroundEndRound) waitSignatures(ctx context.Context) { + remainingTime := sr.remainingTime() + if sr.IsSubroundFinished(sr.Current()) { + return + } + sr.SetWaitingAllSignaturesTimeOut(true) + + select { + case <-time.After(remainingTime): + case <-ctx.Done(): + } + sr.ConsensusChannel() <- true +} + +// maximum time to wait for signatures +func (sr *subroundEndRound) remainingTime() time.Duration { + startTime := sr.RoundHandler().TimeStamp() + maxTime := time.Duration(float64(sr.StartTime()) + float64(sr.EndTime()-sr.StartTime())*waitingAllSigsMaxTimeThreshold) + remainingTime := 
sr.RoundHandler().RemainingTime(startTime, maxTime) + + return remainingTime +} + +// receivedSignature method is called when a signature is received through the signature channel. +// If the signature is valid, then the jobDone map corresponding to the node which sent it, +// is set on true for the subround Signature +func (sr *subroundEndRound) receivedSignature(_ context.Context, cnsDta *consensus.Message) bool { + node := string(cnsDta.PubKey) + pkForLogs := core.GetTrimmedPk(hex.EncodeToString(cnsDta.PubKey)) + + if !sr.IsConsensusDataSet() { + return false + } + + if !sr.IsNodeInConsensusGroup(node) { + sr.PeerHonestyHandler().ChangeScore( + node, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.ValidatorPeerHonestyDecreaseFactor, + ) + + return false + } + + if !sr.IsConsensusDataEqual(cnsDta.BlockHeaderHash) { + return false + } + + if !sr.CanProcessReceivedMessage(cnsDta, sr.RoundHandler().Index(), sr.Current()) { + return false + } + + index, err := sr.ConsensusGroupIndex(node) + if err != nil { + log.Debug("receivedSignature.ConsensusGroupIndex", + "node", pkForLogs, + "error", err.Error()) + return false + } + + err = sr.SigningHandler().StoreSignatureShare(uint16(index), cnsDta.SignatureShare) + if err != nil { + log.Debug("receivedSignature.StoreSignatureShare", + "node", pkForLogs, + "index", index, + "error", err.Error()) + return false + } + + err = sr.SetJobDone(node, bls.SrSignature, true) + if err != nil { + log.Debug("receivedSignature.SetJobDone", + "node", pkForLogs, + "subround", sr.Name(), + "error", err.Error()) + return false + } + + sr.PeerHonestyHandler().ChangeScore( + node, + spos.GetConsensusTopicID(sr.ShardCoordinator()), + spos.ValidatorPeerHonestyIncreaseFactor, + ) + + return true +} + +func (sr *subroundEndRound) checkReceivedSignatures() bool { + isTransitionBlock := common.IsEpochChangeBlockForFlagActivation(sr.GetHeader(), sr.EnableEpochsHandler(), common.AndromedaFlag) + + threshold := sr.Threshold(bls.SrSignature) + if isTransitionBlock { + threshold = core.GetPBFTThreshold(sr.ConsensusGroupSize()) + } + + if sr.FallbackHeaderValidator().ShouldApplyFallbackValidation(sr.GetHeader()) { + threshold = sr.FallbackThreshold(bls.SrSignature) + if isTransitionBlock { + threshold = core.GetPBFTFallbackThreshold(sr.ConsensusGroupSize()) + } + + log.Warn("subroundEndRound.checkReceivedSignatures: fallback validation has been applied", + "minimum number of signatures required", threshold, + "actual number of signatures received", sr.getNumOfSignaturesCollected(), + ) + } + + areSignaturesCollected, numSigs := sr.areSignaturesCollected(threshold) + areAllSignaturesCollected := numSigs == sr.ConsensusGroupSize() + + isSignatureCollectionDone := areAllSignaturesCollected || (areSignaturesCollected && sr.GetWaitingAllSignaturesTimeOut()) + + isSelfJobDone := sr.IsSelfJobDone(bls.SrSignature) + + shouldStopWaitingSignatures := isSelfJobDone && isSignatureCollectionDone + if shouldStopWaitingSignatures { + log.Debug("step 2: signatures collection done", + "subround", sr.Name(), + "signatures received", numSigs, + "total signatures", len(sr.ConsensusGroup()), + "threshold", threshold) + + return true + } + + return false +} + +func (sr *subroundEndRound) getNumOfSignaturesCollected() int { + n := 0 + + for i := 0; i < len(sr.ConsensusGroup()); i++ { + node := sr.ConsensusGroup()[i] + + isSignJobDone, err := sr.JobDone(node, bls.SrSignature) + if err != nil { + log.Debug("getNumOfSignaturesCollected.JobDone", + "node", node, + "subround", sr.Name(), + 
"error", err.Error()) + continue + } + + if isSignJobDone { + n++ + } + } + + return n +} + +// areSignaturesCollected method checks if the signatures received from the nodes, belonging to the current +// jobDone group, are more than the necessary given threshold +func (sr *subroundEndRound) areSignaturesCollected(threshold int) (bool, int) { + n := sr.getNumOfSignaturesCollected() + return n >= threshold, n +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sr *subroundEndRound) IsInterfaceNil() bool { + return sr == nil +} diff --git a/consensus/spos/bls/v2/subroundEndRound_test.go b/consensus/spos/bls/v2/subroundEndRound_test.go new file mode 100644 index 00000000000..bdbc722f110 --- /dev/null +++ b/consensus/spos/bls/v2/subroundEndRound_test.go @@ -0,0 +1,2413 @@ +package v2_test + +import ( + "bytes" + "context" + "errors" + "math/big" + "sync" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/atomic" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-crypto-go/signing" + "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/p2p/factory" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/testscommon" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +func initSubroundEndRoundWithContainer( + container *spos.ConsensusCore, + appStatusHandler core.AppStatusHandler, +) v2.SubroundEndRound { + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithNodesCoordinator(container.NodesCoordinator()) + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + 
chainID, + currentPid, + appStatusHandler, + ) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + appStatusHandler, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + return srEndRound +} + +func initSubroundEndRoundWithContainerAndConsensusState( + container *spos.ConsensusCore, + appStatusHandler core.AppStatusHandler, + consensusState *spos.ConsensusState, + signatureThrottler core.Throttler, +) v2.SubroundEndRound { + ch := make(chan bool, 1) + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + appStatusHandler, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + signatureThrottler, + ) + + return srEndRound +} + +func initSubroundEndRound(appStatusHandler core.AppStatusHandler) v2.SubroundEndRound { + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, appStatusHandler) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + return sr +} + +func TestNewSubroundEndRound(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + t.Run("nil subround should error", func(t *testing.T) { + t.Parallel() + + srEndRound, err := v2.NewSubroundEndRound( + nil, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.Nil(t, srEndRound) + assert.Equal(t, spos.ErrNilSubround, err) + }) + t.Run("nil app status handler should error", func(t *testing.T) { + t.Parallel() + + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + nil, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.Nil(t, srEndRound) + assert.Equal(t, spos.ErrNilAppStatusHandler, err) + }) + t.Run("nil sent signatures tracker should error", func(t *testing.T) { + t.Parallel() + + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + nil, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.Nil(t, srEndRound) + assert.Equal(t, v2.ErrNilSentSignatureTracker, err) + }) + t.Run("nil worker should error", func(t *testing.T) { + t.Parallel() + + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + nil, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.Nil(t, srEndRound) + assert.Equal(t, spos.ErrNilWorker, err) + }) +} + +func 
TestSubroundEndRound_NewSubroundEndRoundNilBlockChainShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetBlockchain(nil) + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilBlockChain, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilBlockProcessorShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetBlockProcessor(nil) + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilBlockProcessor, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilConsensusStateShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + sr.ConsensusStateHandler = nil + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilConsensusState, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilMultiSignerContainerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetMultiSignerContainer(nil) + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + 
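+ // with the multi-signer container nil-ed on the consensus core above, the constructor is expected to return a nil subround together with spos.ErrNilMultiSignerContainer, which the asserts below check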
assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilRoundHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetRoundHandler(nil) + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilRoundHandler, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilSyncTimerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetSyncTimer(nil) + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestSubroundEndRound_NewSubroundEndRoundNilThrottlerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + nil, + ) + + assert.True(t, check.IfNil(srEndRound)) + assert.Equal(t, err, spos.ErrNilThrottler) +} + +func TestSubroundEndRound_NewSubroundEndRoundShouldWork(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, err := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + assert.False(t, check.IfNil(srEndRound)) + 
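+ // happy-path construction: every dependency is provided here, so a non-nil subround and a nil error are expected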
assert.Nil(t, err) +} + +func TestSubroundEndRound_DoEndRoundJobNilHeaderShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(nil) + + r := sr.DoEndRoundJob() + assert.False(t, r) +} + +func TestSubroundEndRound_DoEndRoundJobErrAggregatingSigShouldFail(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + signingHandler := &consensusMocks.SigningHandlerStub{ + AggregateSigsCalled: func(bitmap []byte, epoch uint32) ([]byte, error) { + return nil, crypto.ErrNilHasher + }, + } + container.SetSigningHandler(signingHandler) + + sr.SetHeader(&block.Header{}) + + sr.SetSelfPubKey("A") + + assert.True(t, sr.IsSelfLeader()) + r := sr.DoEndRoundJob() + assert.False(t, r) +} + +func TestSubroundEndRound_DoEndRoundJobErrCommitBlockShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + blProcMock := consensusMocks.InitBlockProcessorMock(container.Marshalizer()) + blProcMock.CommitBlockCalled = func( + header data.HeaderHandler, + body data.BodyHandler, + ) error { + return blockchain.ErrHeaderUnitNil + } + + container.SetBlockProcessor(blProcMock) + sr.SetHeader(&block.Header{}) + + r := sr.DoEndRoundJob() + assert.False(t, r) +} + +func TestSubroundEndRound_DoEndRoundJobErrTimeIsOutShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + remainingTime := -time.Millisecond + roundHandlerMock := &consensusMocks.RoundHandlerMock{ + RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { + return remainingTime + }, + } + + container.SetRoundHandler(roundHandlerMock) + sr.SetHeader(&block.Header{}) + + r := sr.DoEndRoundJob() + assert.False(t, r) +} + +func TestSubroundEndRound_DoEndRoundJobAllOK(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + }) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + sr.SetHeader(&block.Header{}) + + for _, participant := range sr.ConsensusGroup() { + _ = sr.SetJobDone(participant, bls.SrSignature, true) + } + + r := sr.DoEndRoundJob() + assert.True(t, r) +} + +func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetRoundCanceled(true) + + ok := sr.DoEndRoundConsensusCheck() + assert.False(t, ok) +} + +func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnTrueWhenRoundIsFinished(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + sr.SetStatus(bls.SrEndRound, spos.SsFinished) + + ok := sr.DoEndRoundConsensusCheck() + assert.True(t, ok) +} + +func TestSubroundEndRound_DoEndRoundConsensusCheckShouldReturnFalseWhenRoundIsNotFinished(t *testing.T) { + t.Parallel() + + sr := 
initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + + ok := sr.DoEndRoundConsensusCheck() + assert.False(t, ok) +} + +func TestSubroundEndRound_CheckSignaturesValidityShouldErrNilSignature(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + + bitmap := make([]byte, len(sr.ConsensusGroup())/8+1) + bitmap[0] = 0x77 + bitmap[1] = 0x01 + err := sr.CheckSignaturesValidity(bitmap) + + assert.Equal(t, spos.ErrNilSignature, err) +} + +func TestSubroundEndRound_CheckSignaturesValidityShouldReturnNil(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + + for _, pubKey := range sr.ConsensusGroup() { + _ = sr.SetJobDone(pubKey, bls.SrSignature, true) + } + + bitmap := make([]byte, len(sr.ConsensusGroup())/8+1) + bitmap[0] = 0x77 + bitmap[1] = 0x01 + + err := sr.CheckSignaturesValidity(bitmap) + require.Nil(t, err) +} + +func TestSubroundEndRound_CreateAndBroadcastProofShouldBeCalled(t *testing.T) { + t.Parallel() + + chanRcv := make(chan bool, 1) + leaderSigInHdr := []byte("leader sig") + container := consensusMocks.InitConsensusCore() + messenger := &consensusMocks.BroadcastMessengerMock{ + BroadcastEquivalentProofCalled: func(proof data.HeaderProofHandler, pkBytes []byte) error { + chanRcv <- true + return nil + }, + } + container.SetBroadcastMessenger(messenger) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{LeaderSignature: leaderSigInHdr}) + sr.CreateAndBroadcastProof([]byte("sig"), []byte("bitmap")) + + select { + case <-chanRcv: + case <-time.After(100 * time.Millisecond): + assert.Fail(t, "broadcast not called") + } +} + +func TestSubroundEndRound_ReceivedProof(t *testing.T) { + t.Parallel() + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + hdr := &block.Header{Nonce: 37} + container := consensusMocks.InitConsensusCore() + wasCommitBlockCalled := false + bp := &testscommon.BlockProcessorStub{ + CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + wasCommitBlockCalled = true + return nil + }, + } + proofsPool := &dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true // skip signatures waiting + }, + } + container.SetBlockProcessor(bp) + container.SetEquivalentProofsPool(proofsPool) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) + sr.AddReceivedHeader(hdr) + + sr.SetStatus(2, spos.SsFinished) + sr.SetStatus(3, spos.SsNotFinished) + + headerHash := []byte("hash") + sr.SetData(headerHash) + proof := &block.HeaderProof{ + HeaderHash: headerHash, + } + sr.ReceivedProof(proof) + require.True(t, wasCommitBlockCalled) + }) + t.Run("should early return when job is already done", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + bp := &testscommon.BlockProcessorStub{ + CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBlockProcessor(bp) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + _ = sr.SetJobDone(sr.SelfPubKey(), sr.Current(), true) + + sr.ReceivedProof(&block.HeaderProof{}) + }) + t.Run("should early return when header is nil", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + bp := &testscommon.BlockProcessorStub{ 
+ CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBlockProcessor(bp) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(nil) + + proof := &block.HeaderProof{} + + sr.ReceivedProof(proof) + }) + t.Run("should early return when header is not for current consensus", func(t *testing.T) { + t.Parallel() + + hdr := &block.Header{Nonce: 37} + container := consensusMocks.InitConsensusCore() + bp := &testscommon.BlockProcessorStub{ + CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBlockProcessor(bp) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(hdr) + sr.AddReceivedHeader(hdr) + + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should early return when proof is not valid", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{ + VerifyLeaderSignatureCalled: func(header data.HeaderHandler) error { + return errors.New("error") + }, + } + bp := &testscommon.BlockProcessorStub{ + CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + require.Fail(t, "should have not been called") + return nil + }, + } + + container.SetHeaderSigVerifier(headerSigVerifier) + container.SetBlockProcessor(bp) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should early return when consensus data is not set", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + bp := &testscommon.BlockProcessorStub{ + CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBlockProcessor(bp) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetData(nil) + + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should early return when sender is not in consensus group", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + bp := &testscommon.BlockProcessorStub{ + CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBlockProcessor(bp) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should early return when sender is self", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + bp := &testscommon.BlockProcessorStub{ + CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBlockProcessor(bp) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should early return when different data is received", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + 
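+ // the consensus data is set to "Y" below while the received proof carries a different header hash, so ReceivedProof is expected to return early; the CommitBlockCalled stub fails the test if a commit still happens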
bp := &testscommon.BlockProcessorStub{ + CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBlockProcessor(bp) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetData([]byte("Y")) + + proof := &block.HeaderProof{} + sr.ReceivedProof(proof) + }) + t.Run("should early return when proof already received", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + }) + + bp := &testscommon.BlockProcessorStub{ + CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBlockProcessor(bp) + + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + proof := &block.HeaderProof{} + srEndRound.ReceivedProof(proof) + }) +} + +func TestSubroundEndRound_IsOutOfTimeShouldReturnFalse(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + + res := sr.IsOutOfTime() + assert.False(t, res) +} + +func TestSubroundEndRound_IsOutOfTimeShouldReturnTrue(t *testing.T) { + t.Parallel() + + // update roundHandler's mock, so it will calculate for real the duration + container := consensusMocks.InitConsensusCore() + roundHandler := consensusMocks.RoundHandlerMock{RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { + currentTime := time.Now() + elapsedTime := currentTime.Sub(startTime) + remainingTime := maxTime - elapsedTime + + return remainingTime + }} + container.SetRoundHandler(&roundHandler) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + sr.SetRoundTimeStamp(time.Now().AddDate(0, 0, -1)) + + res := sr.IsOutOfTime() + assert.True(t, res) +} + +func TestVerifyNodesOnAggSigVerificationFail(t *testing.T) { + t.Parallel() + + t.Run("fail to get signature share", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + signingHandler := &consensusMocks.SigningHandlerStub{ + SignatureShareCalled: func(index uint16) ([]byte, error) { + return nil, expectedErr + }, + } + + container.SetSigningHandler(signingHandler) + + sr.SetHeader(&block.Header{}) + leader, err := sr.GetLeader() + require.Nil(t, err) + _ = sr.SetJobDone(leader, 
bls.SrSignature, true) + + _, err = sr.VerifyNodesOnAggSigFail(context.TODO()) + require.Equal(t, expectedErr, err) + }) + + t.Run("fail to verify signature share, job done will be set to false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + signingHandler := &consensusMocks.SigningHandlerStub{ + SignatureShareCalled: func(index uint16) ([]byte, error) { + return nil, nil + }, + VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { + return expectedErr + }, + } + + sr.SetHeader(&block.Header{}) + leader, err := sr.GetLeader() + require.Nil(t, err) + _ = sr.SetJobDone(leader, bls.SrSignature, true) + container.SetSigningHandler(signingHandler) + _, err = sr.VerifyNodesOnAggSigFail(context.TODO()) + require.Nil(t, err) + + isJobDone, err := sr.JobDone(leader, bls.SrSignature) + require.Nil(t, err) + require.False(t, isJobDone) + }) + + t.Run("fail to verify signature share, an element will return an error on SignatureShare, should not panic", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + signingHandler := &consensusMocks.SigningHandlerStub{ + SignatureShareCalled: func(index uint16) ([]byte, error) { + if index < 8 { + return nil, nil + } + return nil, expectedErr + }, + VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { + time.Sleep(100 * time.Millisecond) + return expectedErr + }, + VerifyCalled: func(msg, bitmap []byte, epoch uint32) error { + return nil + }, + } + container.SetSigningHandler(signingHandler) + + sr.SetHeader(&block.Header{}) + _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[1], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[2], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[3], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[4], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[5], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[6], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[7], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[8], bls.SrSignature, true) + go func() { + defer func() { + if r := recover(); r != nil { + t.Error("Should not panic") + } + }() + invalidSigners, err := sr.VerifyNodesOnAggSigFail(context.TODO()) + time.Sleep(200 * time.Millisecond) + require.Equal(t, err, expectedErr) + require.Nil(t, invalidSigners) + }() + time.Sleep(time.Second) + + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + signingHandler := &consensusMocks.SigningHandlerStub{ + SignatureShareCalled: func(index uint16) ([]byte, error) { + return nil, nil + }, + VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { + return nil + }, + VerifyCalled: func(msg, bitmap []byte, epoch uint32) error { + return nil + }, + } + container.SetSigningHandler(signingHandler) + + sr.SetHeader(&block.Header{}) + _ = sr.SetJobDone(sr.ConsensusGroup()[0], bls.SrSignature, true) + _ = sr.SetJobDone(sr.ConsensusGroup()[1], bls.SrSignature, true) + invalidSigners, err := sr.VerifyNodesOnAggSigFail(context.TODO()) + 
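+ // every signature share verifies successfully in this scenario, so a nil error and an empty (but non-nil) slice of invalid signers are expected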
require.Nil(t, err) + require.NotNil(t, invalidSigners) + }) +} + +func TestComputeAddSigOnValidNodes(t *testing.T) { + t.Parallel() + + t.Run("invalid number of valid sig shares", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{}) + sr.SetThreshold(bls.SrEndRound, 2) + + _, _, err := sr.ComputeAggSigOnValidNodes() + require.True(t, errors.Is(err, spos.ErrInvalidNumSigShares)) + }) + + t.Run("fail to created aggregated sig", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + signingHandler := &consensusMocks.SigningHandlerStub{ + AggregateSigsCalled: func(bitmap []byte, epoch uint32) ([]byte, error) { + return nil, expectedErr + }, + } + container.SetSigningHandler(signingHandler) + + sr.SetHeader(&block.Header{}) + for _, participant := range sr.ConsensusGroup() { + _ = sr.SetJobDone(participant, bls.SrSignature, true) + } + + _, _, err := sr.ComputeAggSigOnValidNodes() + require.Equal(t, expectedErr, err) + }) + + t.Run("fail to set aggregated sig", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + signingHandler := &consensusMocks.SigningHandlerStub{ + SetAggregatedSigCalled: func(_ []byte) error { + return expectedErr + }, + } + container.SetSigningHandler(signingHandler) + sr.SetHeader(&block.Header{}) + for _, participant := range sr.ConsensusGroup() { + _ = sr.SetJobDone(participant, bls.SrSignature, true) + } + + _, _, err := sr.ComputeAggSigOnValidNodes() + require.Equal(t, expectedErr, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{}) + for _, participant := range sr.ConsensusGroup() { + _ = sr.SetJobDone(participant, bls.SrSignature, true) + } + + bitmap, sig, err := sr.ComputeAggSigOnValidNodes() + require.NotNil(t, bitmap) + require.NotNil(t, sig) + require.Nil(t, err) + }) +} + +func TestSubroundEndRound_DoEndRoundJobByNode(t *testing.T) { + t.Parallel() + + t.Run("equivalent messages flag enabled and message already received", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + wasHasEquivalentProofCalled := false + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + wasHasEquivalentProofCalled = true + return true + }, + }) + + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + + srEndRound, _ := 
v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + srEndRound.SetThreshold(bls.SrSignature, 2) + + for _, participant := range srEndRound.ConsensusGroup() { + _ = srEndRound.SetJobDone(participant, bls.SrSignature, true) + } + + r := srEndRound.DoEndRoundJobByNode() + require.True(t, r) + require.True(t, wasHasEquivalentProofCalled) + }) + + t.Run("should work without equivalent messages flag active", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + numCalls := 0 + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + if numCalls <= 2 { + numCalls++ + return false + } + return true + }, + }) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + verifySigShareNumCalls := 0 + mutex := &sync.Mutex{} + verifyFirstCall := true + signingHandler := &consensusMocks.SigningHandlerStub{ + SignatureShareCalled: func(index uint16) ([]byte, error) { + return nil, nil + }, + VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { + mutex.Lock() + defer mutex.Unlock() + if verifySigShareNumCalls == 0 { + verifySigShareNumCalls++ + return expectedErr + } + + verifySigShareNumCalls++ + return nil + }, + VerifyCalled: func(msg, bitmap []byte, epoch uint32) error { + mutex.Lock() + defer mutex.Unlock() + if verifyFirstCall { + verifyFirstCall = false + return expectedErr + } + + return nil + }, + } + + container.SetSigningHandler(signingHandler) + + sr.SetThreshold(bls.SrEndRound, 2) + + for _, participant := range sr.ConsensusGroup() { + _ = sr.SetJobDone(participant, bls.SrSignature, true) + } + + sr.SetHeader(&block.Header{}) + + r := sr.DoEndRoundJobByNode() + require.True(t, r) + + assert.False(t, verifyFirstCall) + assert.Equal(t, 9, verifySigShareNumCalls) + }) + t.Run("should work with equivalent messages flag active", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + container.SetBlockchain(&testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.HeaderV2{} + }, + }) + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + }) + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + wasIncrementHandlerCalled := false + wasSetStringValueHandlerCalled := false + sh := &statusHandler.AppStatusHandlerStub{ + IncrementHandler: func(key string) { + require.Equal(t, common.MetricCountAcceptedBlocks, key) + wasIncrementHandlerCalled = true + }, + SetStringValueHandler: func(key string, value string) { + require.Equal(t, common.MetricConsensusRoundState, key) + wasSetStringValueHandlerCalled = true + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + sh, + ) + + srEndRound, _ := 
v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + sh, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + srEndRound.SetThreshold(bls.SrEndRound, 2) + + for _, participant := range srEndRound.ConsensusGroup() { + _ = srEndRound.SetJobDone(participant, bls.SrSignature, true) + } + + srEndRound.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + ScheduledRootHash: []byte("sch root hash"), + ScheduledAccumulatedFees: big.NewInt(0), + ScheduledDeveloperFees: big.NewInt(0), + }) + + sr.SetLeader(sr.SelfPubKey()) + + r := srEndRound.DoEndRoundJobByNode() + require.True(t, r) + + require.True(t, wasIncrementHandlerCalled) + require.True(t, wasSetStringValueHandlerCalled) + }) + t.Run("invalid signers should wait for more signatures then work", func(t *testing.T) { + t.Parallel() + + chanSendNewSig := make(chan bool) + container := consensusMocks.InitConsensusCore() + shouldNotFailAnymore := atomic.Flag{} + signingHandler := &consensusMocks.SigningHandlerStub{ + VerifySignatureShareCalled: func(index uint16, sig []byte, msg []byte, epoch uint32) error { + if index == 3 { + return expectedErr + } + return nil + }, + AggregateSigsCalled: func(bitmap []byte, epoch uint32) ([]byte, error) { + if !shouldNotFailAnymore.IsSet() { + return nil, expectedErr // force invalid signers on first aggregation + } + + return []byte("sig"), nil + }, + } + container.SetSigningHandler(signingHandler) + container.SetBlockchain(&testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.HeaderV2{} + }, + }) + cntHasProof := 0 + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + cntHasProof++ + // second check for proof should be after recursive call + if cntHasProof == 3 { + chanSendNewSig <- true + } + return shouldNotFailAnymore.IsSet() + }, + }) + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + consensusSize := sr.ConsensusGroupSize() + threshold := 2*consensusSize/3 + 1 + srEndRound.SetThreshold(bls.SrSignature, threshold) + + for i := 0; i < threshold; i++ { + participant := srEndRound.ConsensusGroup()[i] + _ = srEndRound.SetJobDone(participant, bls.SrSignature, true) + } + + srEndRound.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + ScheduledRootHash: []byte("sch root hash"), + ScheduledAccumulatedFees: big.NewInt(0), + ScheduledDeveloperFees: big.NewInt(0), + }) + + go func() { + for { + select { + case <-chanSendNewSig: + // add one more valid signature and avoid further errors + participant := srEndRound.ConsensusGroup()[threshold] + _ = 
srEndRound.SetJobDone(participant, bls.SrSignature, true) + shouldNotFailAnymore.SetValue(true) + return + case <-time.After(roundTimeDuration): + require.Fail(t, "should have not passed all time") + return + } + } + }() + + r := srEndRound.DoEndRoundJobByNode() + require.True(t, r) + }) +} + +func TestSubroundEndRound_ReceivedInvalidSignersInfo(t *testing.T) { + t.Parallel() + + t.Run("consensus data is not set", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.ConsensusStateHandler.SetData(nil) + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("consensus header is not set", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(nil) + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("received message node is not leader in current round", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("other node"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + + t.Run("received message from self leader should return false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + + t.Run("received message from self multikey leader should return false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return string(pkBytes) == "A" + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + srEndRound.SetSelfPubKey("A") + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + } + + res := srEndRound.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + + t.Run("received hash does not match the hash from current consensus state", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + cnsData := 
consensus.Message{ + BlockHeaderHash: []byte("Y"), + PubKey: []byte("A"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("process received message verification failed, different round index", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + RoundIndex: 1, + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("empty invalid signers", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + InvalidSigners: []byte{}, + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("invalid signers cache already has this message", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + invalidSignersCache := &consensusMocks.InvalidSignersCacheMock{ + CheckKnownInvalidSignersCalled: func(headerHash []byte, invalidSigners []byte) bool { + return true + }, + } + container.SetInvalidSignersCache(invalidSignersCache) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + InvalidSigners: []byte("invalidSignersData"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("invalid signers data", func(t *testing.T) { + t.Parallel() + + messageSigningHandler := &mock.MessageSigningHandlerStub{ + DeserializeCalled: func(messagesBytes []byte) ([]p2p.MessageP2P, error) { + return nil, expectedErr + }, + } + + container := consensusMocks.InitConsensusCore() + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + InvalidSigners: []byte("invalid data"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.False(t, res) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + wasAddInvalidSignersCalled := false + invalidSignersCache := &consensusMocks.InvalidSignersCacheMock{ + AddInvalidSignersCalled: func(headerHash []byte, invalidSigners []byte, invalidPublicKeys []string) { + wasAddInvalidSignersCalled = true + }, + } + container.SetInvalidSignersCache(invalidSignersCache) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.HeaderV2{ + Header: createDefaultHeader(), + }) + cnsData := consensus.Message{ + BlockHeaderHash: []byte("X"), + PubKey: []byte("A"), + InvalidSigners: []byte("invalidSignersData"), + } + + res := sr.ReceivedInvalidSignersInfo(&cnsData) + assert.True(t, res) + require.True(t, wasAddInvalidSignersCalled) + }) +} + +func TestVerifyInvalidSigners(t *testing.T) { + t.Parallel() + + t.Run("failed to deserialize invalidSigners field, should error", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + messageSigningHandler := &mock.MessageSigningHandlerStub{ + DeserializeCalled: func(messagesBytes []byte) 
([]p2p.MessageP2P, error) { + return nil, expectedErr + }, + } + + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + _, err := sr.VerifyInvalidSigners([]byte{}) + require.Equal(t, expectedErr, err) + }) + + t.Run("failed to verify low level p2p message, should error", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + invalidSigners := []p2p.MessageP2P{&factory.Message{ + FromField: []byte("from"), + }} + invalidSignersBytes, _ := container.Marshalizer().Marshal(invalidSigners) + + messageSigningHandler := &mock.MessageSigningHandlerStub{ + DeserializeCalled: func(messagesBytes []byte) ([]p2p.MessageP2P, error) { + require.Equal(t, invalidSignersBytes, messagesBytes) + return invalidSigners, nil + }, + VerifyCalled: func(message p2p.MessageP2P) error { + return expectedErr + }, + } + + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + _, err := sr.VerifyInvalidSigners(invalidSignersBytes) + require.Equal(t, expectedErr, err) + }) + + t.Run("failed to verify signature share", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + pubKey := []byte("A") // it's in consensus + + consensusMsg := &consensus.Message{ + PubKey: pubKey, + } + consensusMsgBytes, _ := container.Marshalizer().Marshal(consensusMsg) + + invalidSigners := []p2p.MessageP2P{&factory.Message{ + FromField: []byte("from"), + DataField: consensusMsgBytes, + }} + invalidSignersBytes, _ := container.Marshalizer().Marshal(invalidSigners) + + messageSigningHandler := &mock.MessageSigningHandlerStub{ + DeserializeCalled: func(messagesBytes []byte) ([]p2p.MessageP2P, error) { + require.Equal(t, invalidSignersBytes, messagesBytes) + return invalidSigners, nil + }, + } + + wasCalled := false + signingHandler := &consensusMocks.SigningHandlerStub{ + VerifySingleSignatureCalled: func(publicKeyBytes []byte, message []byte, signature []byte) error { + wasCalled = true + return errors.New("expected err") + }, + } + + container.SetSigningHandler(signingHandler) + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + _, err := sr.VerifyInvalidSigners(invalidSignersBytes) + require.Nil(t, err) + require.True(t, wasCalled) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + pubKey := []byte("A") // it's in consensus + + consensusMsg := &consensus.Message{ + PubKey: pubKey, + } + consensusMsgBytes, _ := container.Marshalizer().Marshal(consensusMsg) + + invalidSigners := []p2p.MessageP2P{&factory.Message{ + FromField: []byte("from"), + DataField: consensusMsgBytes, + }} + invalidSignersBytes, _ := container.Marshalizer().Marshal(invalidSigners) + + messageSigningHandler := &mock.MessageSignerMock{} + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + _, err := sr.VerifyInvalidSigners(invalidSignersBytes) + require.Nil(t, err) + }) +} + +func TestSubroundEndRound_CreateAndBroadcastInvalidSigners(t *testing.T) { + t.Parallel() + + t.Run("redundancy node should not send while main is active", func(t *testing.T) { + t.Parallel() + + expectedInvalidSigners := []byte("invalid signers") + + 
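+ // a redundancy node whose main machine is still active must not broadcast, so the messenger stub below fails the test if BroadcastConsensusMessage is ever called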
container := consensusMocks.InitConsensusCore() + nodeRedundancy := &mock.NodeRedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool { + return true + }, + IsMainMachineActiveCalled: func() bool { + return true + }, + } + container.SetNodeRedundancyHandler(nodeRedundancy) + messenger := &consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + assert.Fail(t, "should have not been called") + return nil + }, + } + container.SetBroadcastMessenger(messenger) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + sr.CreateAndBroadcastInvalidSigners(expectedInvalidSigners) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + wg := &sync.WaitGroup{} + wg.Add(1) + + expectedInvalidSigners := []byte("invalid signers") + + wasBroadcastConsensusMessageCalled := false + container := consensusMocks.InitConsensusCore() + messenger := &consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + assert.Equal(t, expectedInvalidSigners, message.InvalidSigners) + wasBroadcastConsensusMessageCalled = true + wg.Done() + return nil + }, + } + container.SetBroadcastMessenger(messenger) + + wasAddInvalidSignersCalled := false + invalidSignersCache := &consensusMocks.InvalidSignersCacheMock{ + AddInvalidSignersCalled: func(headerHash []byte, invalidSigners []byte, invalidPublicKeys []string) { + wasAddInvalidSignersCalled = true + }, + } + container.SetInvalidSignersCache(invalidSignersCache) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetSelfPubKey("A") + + sr.CreateAndBroadcastInvalidSigners(expectedInvalidSigners) + + wg.Wait() + + require.True(t, wasBroadcastConsensusMessageCalled) + require.True(t, wasAddInvalidSignersCalled) + }) +} + +func TestGetFullMessagesForInvalidSigners(t *testing.T) { + t.Parallel() + + t.Run("empty p2p messages slice if not in state", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + messageSigningHandler := &mock.MessageSigningHandlerStub{ + SerializeCalled: func(messages []p2p.MessageP2P) ([]byte, error) { + require.Equal(t, 0, len(messages)) + + return []byte{}, nil + }, + } + + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + invalidSigners := []string{"B", "C"} + + invalidSignersBytes, err := sr.GetFullMessagesForInvalidSigners(invalidSigners) + require.Nil(t, err) + require.Equal(t, []byte{}, invalidSignersBytes) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + expectedInvalidSigners := []byte("expectedInvalidSigners") + + messageSigningHandler := &mock.MessageSigningHandlerStub{ + SerializeCalled: func(messages []p2p.MessageP2P) ([]byte, error) { + require.Equal(t, 2, len(messages)) + + return expectedInvalidSigners, nil + }, + } + + container.SetMessageSigningHandler(messageSigningHandler) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.AddMessageWithSignature("B", &p2pmocks.P2PMessageMock{}) + sr.AddMessageWithSignature("C", &p2pmocks.P2PMessageMock{}) + + invalidSigners := []string{"B", "C"} + + invalidSignersBytes, err := sr.GetFullMessagesForInvalidSigners(invalidSigners) + require.Nil(t, err) + require.Equal(t, expectedInvalidSigners, invalidSignersBytes) + }) +} + +func 
TestSubroundEndRound_getMinConsensusGroupIndexOfManagedKeys(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + sr, _ := spos.NewSubround( + bls.SrSignature, + bls.SrEndRound, + -1, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(END_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srEndRound, _ := v2.NewSubroundEndRound( + sr, + v2.ProcessingThresholdPercent, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMocks.ThrottlerStub{}, + ) + + t.Run("no managed keys from consensus group", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return false + } + + assert.Equal(t, 9, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) + t.Run("first managed key in consensus group should return 0", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("A"), pkBytes) + } + + assert.Equal(t, 0, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) + t.Run("third managed key in consensus group should return 2", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("C"), pkBytes) + } + + assert.Equal(t, 2, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) + t.Run("last managed key in consensus group should return 8", func(t *testing.T) { + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return bytes.Equal([]byte("I"), pkBytes) + } + + assert.Equal(t, 8, srEndRound.GetMinConsensusGroupIndexOfManagedKeys()) + }) +} + +func TestSubroundEndRound_ReceivedSignature(t *testing.T) { + t.Parallel() + + sr := initSubroundEndRound(&statusHandler.AppStatusHandlerStub{}) + signature := []byte("signature") + cnsMsg := consensus.NewConsensusMessage( + sr.GetData(), + signature, + nil, + nil, + []byte(sr.ConsensusGroup()[1]), + []byte("sig"), + int(bls.MtSignature), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + + sr.SetHeader(&block.Header{}) + sr.SetData(nil) + r := sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + sr.SetData([]byte("Y")) + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + sr.SetData([]byte("X")) + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + leader, err := sr.GetLeader() + assert.Nil(t, err) + + sr.SetSelfPubKey(leader) + + cnsMsg.PubKey = []byte("X") + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + cnsMsg.PubKey = []byte(sr.ConsensusGroup()[1]) + maxCount := len(sr.ConsensusGroup()) * 2 / 3 + count := 0 + for i := 0; i < len(sr.ConsensusGroup()); i++ { + if sr.ConsensusGroup()[i] != string(cnsMsg.PubKey) { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) + count++ + if count == maxCount { + break + } + } + } + r = sr.ReceivedSignature(cnsMsg) + assert.True(t, r) +} + +func TestSubroundEndRound_ReceivedSignatureStoreShareFailed(t *testing.T) { + t.Parallel() + + errStore := errors.New("signature share store failed") + storeSigShareCalled := false + signingHandler := &consensusMocks.SigningHandlerStub{ + VerifySignatureShareCalled: func(index uint16, sig, msg []byte, epoch uint32) error { + return nil + }, + 
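+ // share verification succeeds but storing the share fails, so ReceivedSignature is expected to return false while still reaching the store call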
StoreSignatureShareCalled: func(index uint16, sig []byte) error { + storeSigShareCalled = true + return errStore + }, + } + + container := consensusMocks.InitConsensusCore() + container.SetSigningHandler(signingHandler) + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + sr.SetHeader(&block.Header{}) + + signature := []byte("signature") + cnsMsg := consensus.NewConsensusMessage( + sr.GetData(), + signature, + nil, + nil, + []byte(sr.ConsensusGroup()[1]), + []byte("sig"), + int(bls.MtSignature), + 0, + chainID, + nil, + nil, + nil, + currentPid, + nil, + ) + + sr.SetData(nil) + r := sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + sr.SetData([]byte("Y")) + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + sr.SetData([]byte("X")) + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + + cnsMsg.PubKey = []byte("X") + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + + cnsMsg.PubKey = []byte(sr.ConsensusGroup()[1]) + maxCount := len(sr.ConsensusGroup()) * 2 / 3 + count := 0 + for i := 0; i < len(sr.ConsensusGroup()); i++ { + if sr.ConsensusGroup()[i] != string(cnsMsg.PubKey) { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) + count++ + if count == maxCount { + break + } + } + } + r = sr.ReceivedSignature(cnsMsg) + assert.False(t, r) + assert.True(t, storeSigShareCalled) +} + +func TestSubroundEndRound_WaitForProof(t *testing.T) { + t.Parallel() + + t.Run("should return true if there is proof", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + }) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + ok := sr.WaitForProof() + require.True(t, ok) + }) + + t.Run("should return true after waiting and finding proof", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + numCalls := 0 + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + if numCalls < 2 { + numCalls++ + return false + } + + return true + }, + }) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + ok := sr.WaitForProof() + require.True(t, ok) + + require.Equal(t, 2, numCalls) + }) + + t.Run("should return false on timeout", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return false + }, + }) + + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + ok := sr.WaitForProof() + require.False(t, ok) + }) +} + +func TestSubroundEndRound_GetEquivalentProofSender(t *testing.T) { + t.Parallel() + + t.Run("for single key, return self pubkey", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + selfKey := sr.SelfPubKey() + + sender := sr.GetEquivalentProofSender() + require.Equal(t, selfKey, sender) + }) + + t.Run("for multi key, return random key", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + suite := 
mcl.NewSuiteBLS12() + kg := signing.NewKeyGenerator(suite) + + mapKeys := generateKeyPairs(kg) + + pubKeys := make([]string, 0) + for pubKey := range mapKeys { + pubKeys = append(pubKeys, pubKey) + } + + nc := &shardingMocks.NodesCoordinatorMock{ + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + defaultSelectionChances := uint32(1) + leader := shardingMocks.NewValidatorMock([]byte(pubKeys[0]), 1, defaultSelectionChances) + return leader, []nodesCoordinator.Validator{ + leader, + shardingMocks.NewValidatorMock([]byte(pubKeys[1]), 1, defaultSelectionChances), + shardingMocks.NewValidatorMock([]byte(pubKeys[2]), 1, defaultSelectionChances), + shardingMocks.NewValidatorMock([]byte(pubKeys[3]), 1, defaultSelectionChances), + shardingMocks.NewValidatorMock([]byte(pubKeys[4]), 1, defaultSelectionChances), + shardingMocks.NewValidatorMock([]byte(pubKeys[5]), 1, defaultSelectionChances), + shardingMocks.NewValidatorMock([]byte(pubKeys[6]), 1, defaultSelectionChances), + shardingMocks.NewValidatorMock([]byte(pubKeys[7]), 1, defaultSelectionChances), + shardingMocks.NewValidatorMock([]byte(pubKeys[8]), 1, defaultSelectionChances), + }, nil + }, + } + container.SetNodesCoordinator(nc) + + keysHandlerMock := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + _, ok := mapKeys[string(pkBytes)] + return ok + }, + } + + consensusState := initializers.InitConsensusStateWithArgs(keysHandlerMock, mapKeys) + sr := initSubroundEndRoundWithContainerAndConsensusState(container, &statusHandler.AppStatusHandlerStub{}, consensusState, &dataRetrieverMocks.ThrottlerStub{}) + sr.SetSelfPubKey("not in consensus") + + selfKey := sr.SelfPubKey() + + sender := sr.GetEquivalentProofSender() + assert.NotEqual(t, selfKey, sender) + }) +} + +func TestSubroundEndRound_SendProof(t *testing.T) { + t.Parallel() + + t.Run("existing proof should not send again", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + proofsPool := &dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + } + container.SetEquivalentProofsPool(proofsPool) + bm := &consensusMocks.BroadcastMessengerMock{ + BroadcastEquivalentProofCalled: func(proof data.HeaderProofHandler, pkBytes []byte) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBroadcastMessenger(bm) + wasSent, err := sr.SendProof() + require.False(t, wasSent) + require.NoError(t, err) + }) + t.Run("not enough signatures should not send proof", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + bm := &consensusMocks.BroadcastMessengerMock{ + BroadcastEquivalentProofCalled: func(proof data.HeaderProofHandler, pkBytes []byte) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBroadcastMessenger(bm) + wasSent, err := sr.SendProof() + require.False(t, wasSent) + require.Error(t, err) + }) + t.Run("signature aggregation failure should not send proof", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + bm := 
&consensusMocks.BroadcastMessengerMock{ + BroadcastEquivalentProofCalled: func(proof data.HeaderProofHandler, pkBytes []byte) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBroadcastMessenger(bm) + signingHandler := &consensusMocks.SigningHandlerStub{ + AggregateSigsCalled: func(bitmap []byte, epoch uint32) ([]byte, error) { + return nil, expectedErr + }, + } + container.SetSigningHandler(signingHandler) + + for _, pubKey := range sr.ConsensusGroup() { + _ = sr.SetJobDone(pubKey, bls.SrSignature, true) + } + + wasSent, err := sr.SendProof() + require.False(t, wasSent) + require.Equal(t, expectedErr, err) + }) + t.Run("no time left should not send proof", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + bm := &consensusMocks.BroadcastMessengerMock{ + BroadcastEquivalentProofCalled: func(proof data.HeaderProofHandler, pkBytes []byte) error { + require.Fail(t, "should have not been called") + return nil + }, + } + container.SetBroadcastMessenger(bm) + roundHandler := &consensusMocks.RoundHandlerMock{ + RemainingTimeCalled: func(startTime time.Time, maxTime time.Duration) time.Duration { + return -1 // no time left + }, + } + container.SetRoundHandler(roundHandler) + + for _, pubKey := range sr.ConsensusGroup() { + _ = sr.SetJobDone(pubKey, bls.SrSignature, true) + } + + wasSent, err := sr.SendProof() + require.False(t, wasSent) + require.Equal(t, v2.ErrTimeOut, err) + }) + t.Run("broadcast failure should not send proof", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + bm := &consensusMocks.BroadcastMessengerMock{ + BroadcastEquivalentProofCalled: func(proof data.HeaderProofHandler, pkBytes []byte) error { + return expectedErr + }, + } + container.SetBroadcastMessenger(bm) + + for _, pubKey := range sr.ConsensusGroup() { + _ = sr.SetJobDone(pubKey, bls.SrSignature, true) + } + + wasSent, err := sr.SendProof() + require.False(t, wasSent) + require.Equal(t, expectedErr, err) + }) + t.Run("should send", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundEndRoundWithContainer(container, &statusHandler.AppStatusHandlerStub{}) + + wasBroadcastEquivalentProofCalled := false + bm := &consensusMocks.BroadcastMessengerMock{ + BroadcastEquivalentProofCalled: func(proof data.HeaderProofHandler, pkBytes []byte) error { + wasBroadcastEquivalentProofCalled = true + return nil + }, + } + container.SetBroadcastMessenger(bm) + + for _, pubKey := range sr.ConsensusGroup() { + _ = sr.SetJobDone(pubKey, bls.SrSignature, true) + } + + wasSent, err := sr.SendProof() + require.True(t, wasSent) + require.NoError(t, err) + require.True(t, wasBroadcastEquivalentProofCalled) + }) +} diff --git a/consensus/spos/bls/v2/subroundSignature.go b/consensus/spos/bls/v2/subroundSignature.go new file mode 100644 index 00000000000..d6cb7fddddc --- /dev/null +++ b/consensus/spos/bls/v2/subroundSignature.go @@ -0,0 +1,315 @@ +package v2 + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + atomicCore "github.com/multiversx/mx-chain-core-go/core/atomic" + "github.com/multiversx/mx-chain-core-go/core/check" + + 
"github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" +) + +const timeSpentBetweenChecks = time.Millisecond + +type subroundSignature struct { + *spos.Subround + appStatusHandler core.AppStatusHandler + sentSignatureTracker spos.SentSignaturesTracker + signatureThrottler core.Throttler +} + +// NewSubroundSignature creates a subroundSignature object +func NewSubroundSignature( + baseSubround *spos.Subround, + appStatusHandler core.AppStatusHandler, + sentSignatureTracker spos.SentSignaturesTracker, + worker spos.WorkerHandler, + signatureThrottler core.Throttler, +) (*subroundSignature, error) { + err := checkNewSubroundSignatureParams( + baseSubround, + ) + if err != nil { + return nil, err + } + if check.IfNil(appStatusHandler) { + return nil, spos.ErrNilAppStatusHandler + } + if check.IfNil(sentSignatureTracker) { + return nil, ErrNilSentSignatureTracker + } + if check.IfNil(worker) { + return nil, spos.ErrNilWorker + } + if check.IfNil(signatureThrottler) { + return nil, spos.ErrNilThrottler + } + + srSignature := subroundSignature{ + Subround: baseSubround, + appStatusHandler: appStatusHandler, + sentSignatureTracker: sentSignatureTracker, + signatureThrottler: signatureThrottler, + } + srSignature.Job = srSignature.doSignatureJob + srSignature.Check = srSignature.doSignatureConsensusCheck + srSignature.Extend = worker.Extend + + return &srSignature, nil +} + +func checkNewSubroundSignatureParams( + baseSubround *spos.Subround, +) error { + if baseSubround == nil { + return spos.ErrNilSubround + } + if check.IfNil(baseSubround.ConsensusStateHandler) { + return spos.ErrNilConsensusState + } + + err := spos.ValidateConsensusCore(baseSubround.ConsensusCoreHandler) + + return err +} + +// doSignatureJob method does the job of the subround Signature +func (sr *subroundSignature) doSignatureJob(ctx context.Context) bool { + if !sr.CanDoSubroundJob(sr.Current()) { + return false + } + if check.IfNil(sr.GetHeader()) { + log.Error("doSignatureJob", "error", spos.ErrNilHeader) + return false + } + + proofAlreadyReceived := sr.EquivalentProofsPool().HasProof(sr.ShardCoordinator().SelfId(), sr.GetData()) + if proofAlreadyReceived { + sr.SetStatus(sr.Current(), spos.SsFinished) + log.Debug("step 2: subround has been finished, proof already received", + "subround", sr.Name()) + + return true + } + + isSelfSingleKeyInConsensusGroup := sr.IsNodeInConsensusGroup(sr.SelfPubKey()) && sr.ShouldConsiderSelfKeyInConsensus() + if isSelfSingleKeyInConsensusGroup { + if !sr.doSignatureJobForSingleKey() { + return false + } + } + + if !sr.doSignatureJobForManagedKeys(ctx) { + return false + } + + sr.SetStatus(sr.Current(), spos.SsFinished) + log.Debug("step 2: subround has been finished", + "subround", sr.Name()) + + return true +} + +func (sr *subroundSignature) createAndSendSignatureMessage(signatureShare []byte, pkBytes []byte) bool { + cnsMsg := consensus.NewConsensusMessage( + sr.GetData(), + signatureShare, + nil, + nil, + pkBytes, + nil, + int(bls.MtSignature), + sr.RoundHandler().Index(), + sr.ChainID(), + nil, + nil, + nil, + sr.GetAssociatedPid(pkBytes), + nil, + ) + + err := sr.BroadcastMessenger().BroadcastConsensusMessage(cnsMsg) + if err != nil { + log.Debug("createAndSendSignatureMessage.BroadcastConsensusMessage", + "error", err.Error(), "pk", pkBytes) + return false + } + + 
log.Debug("step 2: signature has been sent", "pk", pkBytes) + + return true +} + +func (sr *subroundSignature) completeSignatureSubRound(pk string) bool { + err := sr.SetJobDone(pk, sr.Current(), true) + if err != nil { + log.Debug("doSignatureJob.SetSelfJobDone", + "subround", sr.Name(), + "error", err.Error(), + "pk", []byte(pk), + ) + return false + } + + return true +} + +// doSignatureConsensusCheck method checks if the consensus in the subround Signature is achieved +func (sr *subroundSignature) doSignatureConsensusCheck() bool { + if sr.GetRoundCanceled() { + return false + } + + if sr.IsSubroundFinished(sr.Current()) { + return true + } + + if check.IfNil(sr.GetHeader()) { + return false + } + + isSelfInConsensusGroup := sr.IsSelfInConsensusGroup() + if !isSelfInConsensusGroup { + log.Debug("step 2: subround has been finished", + "subround", sr.Name()) + sr.SetStatus(sr.Current(), spos.SsFinished) + + return true + } + + if sr.IsSelfJobDone(sr.Current()) { + log.Debug("step 2: subround has been finished", + "subround", sr.Name()) + sr.SetStatus(sr.Current(), spos.SsFinished) + sr.appStatusHandler.SetStringValue(common.MetricConsensusRoundState, "signed") + + return true + } + + return false +} + +func (sr *subroundSignature) doSignatureJobForManagedKeys(ctx context.Context) bool { + numMultiKeysSignaturesSent := int32(0) + sentSigForAllKeys := atomicCore.Flag{} + sentSigForAllKeys.SetValue(true) + + wg := sync.WaitGroup{} + + for idx, pk := range sr.ConsensusGroup() { + pkBytes := []byte(pk) + if !sr.IsKeyManagedBySelf(pkBytes) { + continue + } + + if sr.IsJobDone(pk, sr.Current()) { + continue + } + + err := sr.checkGoRoutinesThrottler(ctx) + if err != nil { + return false + } + sr.signatureThrottler.StartProcessing() + wg.Add(1) + + go func(idx int, pk string) { + defer sr.signatureThrottler.EndProcessing() + + signatureSent := sr.sendSignatureForManagedKey(idx, pk) + if signatureSent { + atomic.AddInt32(&numMultiKeysSignaturesSent, 1) + } else { + sentSigForAllKeys.SetValue(false) + } + wg.Done() + }(idx, pk) + } + + wg.Wait() + + if numMultiKeysSignaturesSent > 0 { + log.Debug("step 2: multi keys signatures have been sent", "num", numMultiKeysSignaturesSent) + } + + return sentSigForAllKeys.IsSet() +} + +func (sr *subroundSignature) sendSignatureForManagedKey(idx int, pk string) bool { + pkBytes := []byte(pk) + + signatureShare, err := sr.SigningHandler().CreateSignatureShareForPublicKey( + sr.GetData(), + uint16(idx), + sr.GetHeader().GetEpoch(), + pkBytes, + ) + if err != nil { + log.Debug("sendSignatureForManagedKey.CreateSignatureShareForPublicKey", "error", err.Error()) + return false + } + + // with the equivalent messages feature on, signatures from all managed keys must be broadcast, as the aggregation is done by any participant + ok := sr.createAndSendSignatureMessage(signatureShare, pkBytes) + if !ok { + return false + } + sr.sentSignatureTracker.SignatureSent(pkBytes) + + return sr.completeSignatureSubRound(pk) +} + +func (sr *subroundSignature) checkGoRoutinesThrottler(ctx context.Context) error { + for { + if sr.signatureThrottler.CanProcess() { + break + } + select { + case <-time.After(timeSpentBetweenChecks): + continue + case <-ctx.Done(): + return fmt.Errorf("%w while checking the throttler", spos.ErrTimeIsOut) + } + } + return nil +} + +func (sr *subroundSignature) doSignatureJobForSingleKey() bool { + selfIndex, err := sr.SelfConsensusGroupIndex() + if err != nil { + log.Debug("doSignatureJobForSingleKey.SelfConsensusGroupIndex: not in consensus group") + 
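+ // without an own index in the consensus group there is nothing to sign in single key mode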
return false + } + + signatureShare, err := sr.SigningHandler().CreateSignatureShareForPublicKey( + sr.GetData(), + uint16(selfIndex), + sr.GetHeader().GetEpoch(), + []byte(sr.SelfPubKey()), + ) + if err != nil { + log.Debug("doSignatureJobForSingleKey.CreateSignatureShareForPublicKey", "error", err.Error()) + return false + } + + // leader also sends his signature here + ok := sr.createAndSendSignatureMessage(signatureShare, []byte(sr.SelfPubKey())) + if !ok { + return false + } + + return sr.completeSignatureSubRound(sr.SelfPubKey()) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (sr *subroundSignature) IsInterfaceNil() bool { + return sr == nil +} diff --git a/consensus/spos/bls/v2/subroundSignature_test.go b/consensus/spos/bls/v2/subroundSignature_test.go new file mode 100644 index 00000000000..40470f48f1d --- /dev/null +++ b/consensus/spos/bls/v2/subroundSignature_test.go @@ -0,0 +1,1029 @@ +package v2_test + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + dataRetrieverMock "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/testscommon" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +const setThresholdJobsDone = "threshold" + +func initSubroundSignatureWithContainer(container *spos.ConsensusCore) v2.SubroundSignature { + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + return srSignature +} + +func initSubroundSignature() v2.SubroundSignature { + container := consensusMocks.InitConsensusCore() + return initSubroundSignatureWithContainer(container) +} + +func TestNewSubroundSignature(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + 
executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + t.Run("nil subround should error", func(t *testing.T) { + t.Parallel() + + srSignature, err := v2.NewSubroundSignature( + nil, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.Nil(t, srSignature) + assert.Equal(t, spos.ErrNilSubround, err) + }) + t.Run("nil worker should error", func(t *testing.T) { + t.Parallel() + + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + nil, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.Nil(t, srSignature) + assert.Equal(t, spos.ErrNilWorker, err) + }) + t.Run("nil app status handler should error", func(t *testing.T) { + t.Parallel() + + srSignature, err := v2.NewSubroundSignature( + sr, + nil, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.Nil(t, srSignature) + assert.Equal(t, spos.ErrNilAppStatusHandler, err) + }) + t.Run("nil sent signatures tracker should error", func(t *testing.T) { + t.Parallel() + + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + nil, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.Nil(t, srSignature) + assert.Equal(t, v2.ErrNilSentSignatureTracker, err) + }) + + t.Run("nil signatureThrottler should error", func(t *testing.T) { + t.Parallel() + + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + nil, + ) + + assert.Nil(t, srSignature) + assert.Equal(t, spos.ErrNilThrottler, err) + }) +} + +func TestSubroundSignature_NewSubroundSignatureNilConsensusStateShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + sr.ConsensusStateHandler = nil + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilConsensusState, err) +} + +func TestSubroundSignature_NewSubroundSignatureNilHasherShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetHasher(nil) + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) 
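+ // the container was stripped of its hasher above, so construction must fail with spos.ErrNilHasher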
+ + assert.True(t, check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilHasher, err) +} + +func TestSubroundSignature_NewSubroundSignatureNilMultiSignerContainerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetMultiSignerContainer(nil) + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) +} + +func TestSubroundSignature_NewSubroundSignatureNilRoundHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetRoundHandler(nil) + + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilRoundHandler, err) +} + +func TestSubroundSignature_NewSubroundSignatureNilSyncTimerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + container.SetSyncTimer(nil) + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srSignature)) + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestSubroundSignature_NewSubroundSignatureNilAppStatusHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srSignature, err := v2.NewSubroundSignature( + sr, + nil, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.True(t, check.IfNil(srSignature)) + 
assert.Equal(t, spos.ErrNilAppStatusHandler, err) +} + +func TestSubroundSignature_NewSubroundSignatureShouldWork(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srSignature, err := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + assert.False(t, check.IfNil(srSignature)) + assert.Nil(t, err) +} + +func TestSubroundSignature_DoSignatureJob(t *testing.T) { + t.Parallel() + + t.Run("job done should return false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + sr.SetStatus(bls.SrSignature, spos.SsFinished) + + r := sr.DoSignatureJob() + assert.False(t, r) + }) + t.Run("nil header should return false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + sr.SetHeader(nil) + + r := sr.DoSignatureJob() + assert.False(t, r) + }) + t.Run("proof already received should return true", func(t *testing.T) { + t.Parallel() + + providedHash := []byte("providedHash") + container := consensusMocks.InitConsensusCore() + container.SetEquivalentProofsPool(&dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return string(headerHash) == string(providedHash) + }, + }) + sr := initSubroundSignatureWithContainer(container) + sr.SetData(providedHash) + sr.SetHeader(&block.Header{}) + + r := sr.DoSignatureJob() + assert.True(t, r) + }) + t.Run("single key error should return false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + + sr.SetHeader(&block.Header{}) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + return expectedErr + }, + }) + r := sr.DoSignatureJob() + assert.False(t, r) + }) + t.Run("single key mode should work", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + + sr.SetHeader(&block.Header{}) + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + if string(message.PubKey) != leader || message.MsgType != int64(bls.MtSignature) { + assert.Fail(t, "should have not been called") + } + return nil + }, + }) + r := sr.DoSignatureJob() + assert.True(t, r) + + assert.False(t, sr.GetRoundCanceled()) + assert.Nil(t, err) + leaderJobDone, err := sr.JobDone(leader, bls.SrSignature) + assert.NoError(t, err) + assert.True(t, leaderJobDone) + assert.True(t, sr.IsSubroundFinished(bls.SrSignature)) + }) + t.Run("multikey mode should work", func(t *testing.T) { + t.Parallel() + + container 
:= consensusMocks.InitConsensusCore() + + signingHandler := &consensusMocks.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(msg []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + return []byte("SIG"), nil + }, + } + container.SetSigningHandler(signingHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + signatureSentForPks := make(map[string]struct{}) + mutex := sync.Mutex{} + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + mutex.Lock() + signatureSentForPks[string(pkBytes)] = struct{}{} + mutex.Unlock() + }, + }, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + sr.SetHeader(&block.Header{}) + signaturesBroadcast := make(map[string]int) + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + mutex.Lock() + signaturesBroadcast[string(message.PubKey)]++ + mutex.Unlock() + return nil + }, + }) + + sr.SetSelfPubKey("OTHER") + + r := srSignature.DoSignatureJob() + assert.True(t, r) + + assert.False(t, sr.GetRoundCanceled()) + assert.True(t, sr.IsSubroundFinished(bls.SrSignature)) + + for _, pk := range sr.ConsensusGroup() { + isJobDone, err := sr.JobDone(pk, bls.SrSignature) + assert.NoError(t, err) + assert.True(t, isJobDone) + } + + expectedMap := map[string]struct{}{"A": {}, "B": {}, "C": {}, "D": {}, "E": {}, "F": {}, "G": {}, "H": {}, "I": {}} + assert.Equal(t, expectedMap, signatureSentForPks) + + // leader also sends his signature + expectedBroadcastMap := map[string]int{"A": 1, "B": 1, "C": 1, "D": 1, "E": 1, "F": 1, "G": 1, "H": 1, "I": 1} + assert.Equal(t, expectedBroadcastMap, signaturesBroadcast) + }) +} + +func TestSubroundSignature_SendSignature(t *testing.T) { + t.Parallel() + + t.Run("sendSignatureForManagedKey will return false because of error", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + + container.SetSigningHandler(&consensusMocks.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(message []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + return make([]byte, 0), expErr + }, + }) + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetHeader(&block.Header{}) + + signatureSentForPks := make(map[string]struct{}) + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + 
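+ // record every public key for which the tracker is notified that a signature was sent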
SignatureSentCalled: func(pkBytes []byte) { + signatureSentForPks[string(pkBytes)] = struct{}{} + }, + }, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + r := srSignature.SendSignatureForManagedKey(0, "a") + + assert.False(t, r) + }) + + t.Run("sendSignatureForManagedKey should be false", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + container.SetSigningHandler(&consensusMocks.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(message []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + return []byte("SIG"), nil + }, + }) + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + return fmt.Errorf("error") + }, + }) + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetHeader(&block.Header{}) + + signatureSentForPks := make(map[string]struct{}) + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + signatureSentForPks[string(pkBytes)] = struct{}{} + }, + }, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + r := srSignature.SendSignatureForManagedKey(1, "a") + + assert.False(t, r) + }) + + t.Run("SentSignature should be called", func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + container.SetSigningHandler(&consensusMocks.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(message []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + return []byte("SIG"), nil + }, + }) + + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + return nil + }, + }) + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetHeader(&block.Header{}) + + signatureSentForPks := make(map[string]struct{}) + varCalled := false + 
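+ // varCalled flips to true once the sent signature tracker is notified for the managed key, which is what this subtest asserts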
srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + signatureSentForPks[string(pkBytes)] = struct{}{} + varCalled = true + }, + }, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + _ = srSignature.SendSignatureForManagedKey(1, "a") + + assert.True(t, varCalled) + }) +} + +func TestSubroundSignature_DoSignatureJobForManagedKeys(t *testing.T) { + t.Parallel() + + t.Run("should work", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + signingHandler := &consensusMocks.SigningHandlerStub{ + CreateSignatureShareForPublicKeyCalled: func(msg []byte, index uint16, epoch uint32, publicKeyBytes []byte) ([]byte, error) { + return []byte("SIG"), nil + }, + } + container.SetSigningHandler(signingHandler) + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + signatureSentForPks := make(map[string]struct{}) + mutex := sync.Mutex{} + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{ + SignatureSentCalled: func(pkBytes []byte) { + mutex.Lock() + signatureSentForPks[string(pkBytes)] = struct{}{} + mutex.Unlock() + }, + }, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{}, + ) + + sr.SetHeader(&block.Header{}) + signaturesBroadcast := make(map[string]int) + container.SetBroadcastMessenger(&consensusMocks.BroadcastMessengerMock{ + BroadcastConsensusMessageCalled: func(message *consensus.Message) error { + mutex.Lock() + signaturesBroadcast[string(message.PubKey)]++ + mutex.Unlock() + return nil + }, + }) + + sr.SetSelfPubKey("OTHER") + + r := srSignature.DoSignatureJobForManagedKeys(context.TODO()) + assert.True(t, r) + + for _, pk := range sr.ConsensusGroup() { + isJobDone, err := sr.JobDone(pk, bls.SrSignature) + assert.NoError(t, err) + assert.True(t, isJobDone) + } + + expectedMap := map[string]struct{}{"A": {}, "B": {}, "C": {}, "D": {}, "E": {}, "F": {}, "G": {}, "H": {}, "I": {}} + assert.Equal(t, expectedMap, signatureSentForPks) + + expectedBroadcastMap := map[string]int{"A": 1, "B": 1, "C": 1, "D": 1, "E": 1, "F": 1, "G": 1, "H": 1, "I": 1} + assert.Equal(t, expectedBroadcastMap, signaturesBroadcast) + }) + + t.Run("should fail", func(t *testing.T) { + t.Parallel() + container := consensusMocks.InitConsensusCore() + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + container.SetEnableEpochsHandler(enableEpochsHandler) + + consensusState := initializers.InitConsensusStateWithKeysHandler( + &testscommon.KeysHandlerStub{ + 
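+ // every consensus key is reported as managed so the multikey path is exercised before the cancelled context and blocked throttler make it fail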
IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return true + }, + }, + ) + ch := make(chan bool, 1) + + sr, _ := spos.NewSubround( + bls.SrBlock, + bls.SrSignature, + bls.SrEndRound, + int64(70*roundTimeDuration/100), + int64(85*roundTimeDuration/100), + "(SIGNATURE)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + srSignature, _ := v2.NewSubroundSignature( + sr, + &statusHandler.AppStatusHandlerStub{}, + &testscommon.SentSignatureTrackerStub{}, + &consensusMocks.SposWorkerMock{}, + &dataRetrieverMock.ThrottlerStub{ + CanProcessCalled: func() bool { + return false + }, + }, + ) + + sr.SetHeader(&block.Header{}) + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + r := srSignature.DoSignatureJobForManagedKeys(ctx) + assert.False(t, r) + }) +} + +func TestSubroundSignature_DoSignatureConsensusCheck(t *testing.T) { + t.Parallel() + + t.Run("round canceled should return false", func(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + sr.SetRoundCanceled(true) + assert.False(t, sr.DoSignatureConsensusCheck()) + }) + t.Run("subround already finished should return true", func(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + sr.SetStatus(bls.SrSignature, spos.SsFinished) + assert.True(t, sr.DoSignatureConsensusCheck()) + }) + t.Run("sig collection done should return true", func(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + + for i := 0; i < sr.Threshold(bls.SrSignature); i++ { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) + } + + sr.SetHeader(&block.HeaderV2{}) + assert.True(t, sr.DoSignatureConsensusCheck()) + }) + t.Run("sig collection failed should return false", func(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + sr.SetHeader(&block.HeaderV2{Header: createDefaultHeader()}) + assert.False(t, sr.DoSignatureConsensusCheck()) + }) + t.Run("not all sig collected in time should return false", func(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + sr.SetHeader(&block.HeaderV2{Header: createDefaultHeader()}) + assert.False(t, sr.DoSignatureConsensusCheck()) + }) + t.Run("nil header should return false", func(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + sr.SetHeader(nil) + assert.False(t, sr.DoSignatureConsensusCheck()) + }) + t.Run("node not in consensus group should return true", func(t *testing.T) { + t.Parallel() + + sr := initSubroundSignature() + sr.SetHeader(&block.HeaderV2{Header: createDefaultHeader()}) + sr.SetSelfPubKey("X") + assert.True(t, sr.DoSignatureConsensusCheck()) + }) +} + +func TestSubroundSignature_DoSignatureConsensusCheckAllSignaturesCollected(t *testing.T) { + t.Parallel() + t.Run("with flag active, should return true", testSubroundSignatureDoSignatureConsensusCheck(argTestSubroundSignatureDoSignatureConsensusCheck{ + flagActive: true, + jobsDone: "all", + expectedResult: true, + })) +} + +func TestSubroundSignature_DoSignatureConsensusCheckEnoughButNotAllSignaturesCollectedAndTimeIsOut(t *testing.T) { + t.Parallel() + + t.Run("with flag active, should return true", testSubroundSignatureDoSignatureConsensusCheck(argTestSubroundSignatureDoSignatureConsensusCheck{ + flagActive: true, + jobsDone: setThresholdJobsDone, + expectedResult: true, + })) +} + +type argTestSubroundSignatureDoSignatureConsensusCheck struct { + flagActive bool + jobsDone string + expectedResult bool +} + +func 
testSubroundSignatureDoSignatureConsensusCheck(args argTestSubroundSignatureDoSignatureConsensusCheck) func(t *testing.T) { + return func(t *testing.T) { + t.Parallel() + + container := consensusMocks.InitConsensusCore() + sr := initSubroundSignatureWithContainer(container) + + if !args.flagActive { + leader, err := sr.GetLeader() + assert.Nil(t, err) + sr.SetSelfPubKey(leader) + } + + numberOfJobsDone := sr.ConsensusGroupSize() + if args.jobsDone == setThresholdJobsDone { + numberOfJobsDone = sr.Threshold(bls.SrSignature) + } + for i := 0; i < numberOfJobsDone; i++ { + _ = sr.SetJobDone(sr.ConsensusGroup()[i], bls.SrSignature, true) + } + + sr.SetHeader(&block.HeaderV2{}) + assert.Equal(t, args.expectedResult, sr.DoSignatureConsensusCheck()) + } +} diff --git a/consensus/spos/bls/v2/subroundStartRound.go b/consensus/spos/bls/v2/subroundStartRound.go new file mode 100644 index 00000000000..4e3be13f5cd --- /dev/null +++ b/consensus/spos/bls/v2/subroundStartRound.go @@ -0,0 +1,354 @@ +package v2 + +import ( + "context" + "encoding/hex" + "fmt" + "sync" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + outportcore "github.com/multiversx/mx-chain-core-go/data/outport" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/outport" + "github.com/multiversx/mx-chain-go/outport/disabled" +) + +// subroundStartRound defines the data needed by the subround StartRound +type subroundStartRound struct { + *spos.Subround + processingThresholdPercentage int + + sentSignatureTracker spos.SentSignaturesTracker + worker spos.WorkerHandler + outportHandler outport.OutportHandler + outportMutex sync.RWMutex +} + +// NewSubroundStartRound creates a subroundStartRound object +func NewSubroundStartRound( + baseSubround *spos.Subround, + processingThresholdPercentage int, + sentSignatureTracker spos.SentSignaturesTracker, + worker spos.WorkerHandler, +) (*subroundStartRound, error) { + err := checkNewSubroundStartRoundParams( + baseSubround, + ) + if err != nil { + return nil, err + } + if check.IfNil(sentSignatureTracker) { + return nil, ErrNilSentSignatureTracker + } + if check.IfNil(worker) { + return nil, spos.ErrNilWorker + } + + srStartRound := subroundStartRound{ + Subround: baseSubround, + processingThresholdPercentage: processingThresholdPercentage, + sentSignatureTracker: sentSignatureTracker, + worker: worker, + outportHandler: disabled.NewDisabledOutport(), + outportMutex: sync.RWMutex{}, + } + srStartRound.Job = srStartRound.doStartRoundJob + srStartRound.Check = srStartRound.doStartRoundConsensusCheck + srStartRound.Extend = worker.Extend + baseSubround.EpochStartRegistrationHandler().RegisterHandler(&srStartRound) + + return &srStartRound, nil +} + +func checkNewSubroundStartRoundParams( + baseSubround *spos.Subround, +) error { + if baseSubround == nil { + return spos.ErrNilSubround + } + if check.IfNil(baseSubround.ConsensusStateHandler) { + return spos.ErrNilConsensusState + } + + err := spos.ValidateConsensusCore(baseSubround.ConsensusCoreHandler) + + return err +} + +// SetOutportHandler method sets outport handler +func (sr *subroundStartRound) SetOutportHandler(outportHandler outport.OutportHandler) error { + if check.IfNil(outportHandler) { + return 
outport.ErrNilDriver + } + + sr.outportMutex.Lock() + sr.outportHandler = outportHandler + sr.outportMutex.Unlock() + + return nil +} + +// doStartRoundJob method does the job of the subround StartRound +func (sr *subroundStartRound) doStartRoundJob(_ context.Context) bool { + sr.ResetConsensusState() + sr.SetRoundIndex(sr.RoundHandler().Index()) + sr.SetRoundTimeStamp(sr.RoundHandler().TimeStamp()) + topic := spos.GetConsensusTopicID(sr.ShardCoordinator()) + sr.GetAntiFloodHandler().ResetForTopic(topic) + sr.worker.ResetConsensusMessages() + sr.worker.ResetInvalidSignersCache() + + return true +} + +// doStartRoundConsensusCheck method checks if the consensus is achieved in the subround StartRound +func (sr *subroundStartRound) doStartRoundConsensusCheck() bool { + if sr.GetRoundCanceled() { + return false + } + + if sr.IsSubroundFinished(sr.Current()) { + return true + } + + if sr.initCurrentRound() { + return true + } + + return false +} + +func (sr *subroundStartRound) initCurrentRound() bool { + nodeState := sr.BootStrapper().GetNodeState() + if nodeState != common.NsSynchronized { // if node is not synchronized yet, it has to continue the bootstrapping mechanism + return false + } + + sr.AppStatusHandler().SetStringValue(common.MetricConsensusRoundState, "") + + err := sr.generateNextConsensusGroup(sr.RoundHandler().Index()) + if err != nil { + log.Debug("initCurrentRound.generateNextConsensusGroup", + "round index", sr.RoundHandler().Index(), + "error", err.Error()) + + sr.SetRoundCanceled(true) + + return false + } + + if sr.NodeRedundancyHandler().IsRedundancyNode() { + sr.NodeRedundancyHandler().AdjustInactivityIfNeeded( + sr.SelfPubKey(), + sr.ConsensusGroup(), + sr.RoundHandler().Index(), + ) + // we should not return here, the multikey redundancy system relies on it + // the NodeRedundancyHandler "thinks" it is in redundancy mode even if we use the multikey redundancy system + } + + leader, err := sr.GetLeader() + if err != nil { + log.Debug("initCurrentRound.GetLeader", "error", err.Error()) + + sr.SetRoundCanceled(true) + + return false + } + + msg := sr.GetLeaderStartRoundMessage() + if len(msg) != 0 { + sr.AppStatusHandler().Increment(common.MetricCountLeader) + sr.AppStatusHandler().SetStringValue(common.MetricConsensusRoundState, "proposed") + sr.AppStatusHandler().SetStringValue(common.MetricConsensusState, "proposer") + } + + log.Debug("step 0: preparing the round", + "leader", core.GetTrimmedPk(hex.EncodeToString([]byte(leader))), + "messsage", msg) + sr.sentSignatureTracker.StartRound() + + pubKeys := sr.ConsensusGroup() + numMultiKeysInConsensusGroup := sr.computeNumManagedKeysInConsensusGroup(pubKeys) + if numMultiKeysInConsensusGroup > 0 { + log.Debug("in consensus group with multi keys identities", "num", numMultiKeysInConsensusGroup) + } + + sr.indexRoundIfNeeded(pubKeys) + + if !sr.IsSelfInConsensusGroup() { + log.Debug("not in consensus group") + sr.AppStatusHandler().SetStringValue(common.MetricConsensusState, "not in consensus group") + } else { + if !sr.IsSelfLeader() { + sr.AppStatusHandler().Increment(common.MetricCountConsensus) + sr.AppStatusHandler().SetStringValue(common.MetricConsensusState, "participant") + } + } + + err = sr.SigningHandler().Reset(pubKeys) + if err != nil { + log.Debug("initCurrentRound.Reset", "error", err.Error()) + + sr.SetRoundCanceled(true) + + return false + } + + startTime := sr.GetRoundTimeStamp() + maxTime := sr.RoundHandler().TimeDuration() * time.Duration(sr.processingThresholdPercentage) / 100 + if 
sr.RoundHandler().RemainingTime(startTime, maxTime) < 0 { + log.Debug("canceled round, time is out", + "round", sr.SyncTimer().FormattedCurrentTime(), sr.RoundHandler().Index(), + "subround", sr.Name()) + + sr.SetRoundCanceled(true) + + return false + } + + sr.SetStatus(sr.Current(), spos.SsFinished) + + // execute stored messages which were received in this new round but before this initialisation + go sr.worker.ExecuteStoredMessages() + + return true +} + +func (sr *subroundStartRound) computeNumManagedKeysInConsensusGroup(pubKeys []string) int { + numMultiKeysInConsensusGroup := 0 + for _, pk := range pubKeys { + pkBytes := []byte(pk) + if sr.IsKeyManagedBySelf(pkBytes) { + numMultiKeysInConsensusGroup++ + log.Trace("in consensus group with multi key", + "pk", core.GetTrimmedPk(hex.EncodeToString(pkBytes))) + } + sr.IncrementRoundsWithoutReceivedMessages(pkBytes) + } + + return numMultiKeysInConsensusGroup +} + +func (sr *subroundStartRound) indexRoundIfNeeded(pubKeys []string) { + sr.outportMutex.RLock() + defer sr.outportMutex.RUnlock() + + if !sr.outportHandler.HasDrivers() { + return + } + + currentHeader := sr.Blockchain().GetCurrentBlockHeader() + if check.IfNil(currentHeader) { + currentHeader = sr.Blockchain().GetGenesisHeader() + } + + epoch := currentHeader.GetEpoch() + shardId := sr.ShardCoordinator().SelfId() + nodesCoordinatorShardID, err := sr.NodesCoordinator().ShardIdForEpoch(epoch) + if err != nil { + log.Debug("initCurrentRound.ShardIdForEpoch", + "epoch", epoch, + "error", err.Error()) + return + } + + if shardId != nodesCoordinatorShardID { + log.Debug("initCurrentRound.ShardIdForEpoch", + "epoch", epoch, + "shardCoordinator.ShardID", shardId, + "nodesCoordinator.ShardID", nodesCoordinatorShardID) + return + } + + round := sr.RoundHandler().Index() + + roundInfo := &outportcore.RoundInfo{ + Round: uint64(round), + SignersIndexes: make([]uint64, 0), + BlockWasProposed: false, + ShardId: shardId, + Epoch: epoch, + Timestamp: uint64(sr.GetRoundTimeStamp().Unix()), + } + roundsInfo := &outportcore.RoundsInfo{ + ShardID: shardId, + RoundsInfo: []*outportcore.RoundInfo{roundInfo}, + } + sr.outportHandler.SaveRoundsInfo(roundsInfo) +} + +func (sr *subroundStartRound) generateNextConsensusGroup(roundIndex int64) error { + currentHeader := sr.Blockchain().GetCurrentBlockHeader() + if check.IfNil(currentHeader) { + currentHeader = sr.Blockchain().GetGenesisHeader() + if check.IfNil(currentHeader) { + return spos.ErrNilHeader + } + } + + randomSeed := currentHeader.GetRandSeed() + + log.Debug("random source for the next consensus group", + "rand", randomSeed) + + shardId := sr.ShardCoordinator().SelfId() + + leader, nextConsensusGroup, err := sr.GetNextConsensusGroup( + randomSeed, + uint64(sr.GetRoundIndex()), + shardId, + sr.NodesCoordinator(), + currentHeader.GetEpoch(), + ) + if err != nil { + return err + } + + log.Trace("consensus group is formed by next validators:", + "round", roundIndex) + + for i := 0; i < len(nextConsensusGroup); i++ { + log.Trace(core.GetTrimmedPk(hex.EncodeToString([]byte(nextConsensusGroup[i])))) + } + + sr.SetConsensusGroup(nextConsensusGroup) + sr.SetLeader(leader) + + consensusGroupSizeForEpoch := sr.NodesCoordinator().ConsensusGroupSizeForShardAndEpoch(shardId, currentHeader.GetEpoch()) + sr.SetConsensusGroupSize(consensusGroupSizeForEpoch) + + return nil +} + +// EpochStartPrepare wis called when an epoch start event is observed, but not yet confirmed/committed. 
+// Some components may need to do initialisation on this event +func (sr *subroundStartRound) EpochStartPrepare(metaHdr data.HeaderHandler, _ data.BodyHandler) { + log.Trace(fmt.Sprintf("epoch %d start prepare in consensus", metaHdr.GetEpoch())) +} + +// EpochStartAction is called upon a start of epoch event. +func (sr *subroundStartRound) EpochStartAction(hdr data.HeaderHandler) { + log.Trace(fmt.Sprintf("epoch %d start action in consensus", hdr.GetEpoch())) + + sr.changeEpoch(hdr.GetEpoch()) +} + +func (sr *subroundStartRound) changeEpoch(currentEpoch uint32) { + epochNodes, err := sr.NodesCoordinator().GetConsensusWhitelistedNodes(currentEpoch) + if err != nil { + panic(fmt.Sprintf("consensus changing epoch failed with error %s", err.Error())) + } + + sr.SetEligibleList(epochNodes) +} + +// NotifyOrder returns the notification order for a start of epoch event +func (sr *subroundStartRound) NotifyOrder() uint32 { + return common.ConsensusStartRoundOrder +} diff --git a/consensus/spos/bls/v2/subroundStartRound_test.go b/consensus/spos/bls/v2/subroundStartRound_test.go new file mode 100644 index 00000000000..fa6853f0314 --- /dev/null +++ b/consensus/spos/bls/v2/subroundStartRound_test.go @@ -0,0 +1,1117 @@ +package v2_test + +import ( + "fmt" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/data" + outportcore "github.com/multiversx/mx-chain-core-go/data/outport" + "github.com/stretchr/testify/require" + + v2 "github.com/multiversx/mx-chain-go/consensus/spos/bls/v2" + processMock "github.com/multiversx/mx-chain-go/process/mock" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/consensus/initializers" + "github.com/multiversx/mx-chain-go/testscommon/outport" + + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" +) + +var expErr = fmt.Errorf("expected error") + +func defaultSubroundStartRoundFromSubround(sr *spos.Subround) (v2.SubroundStartRound, error) { + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + + return startRound, err +} + +func defaultWithoutErrorSubroundStartRoundFromSubround(sr *spos.Subround) v2.SubroundStartRound { + startRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + + return startRound +} + +func defaultSubround( + consensusState *spos.ConsensusState, + ch chan bool, + container spos.ConsensusCoreHandler, +) (*spos.Subround, error) { + + return spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(0*roundTimeDuration/100), + int64(5*roundTimeDuration/100), + 
"(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) +} + +func initSubroundStartRoundWithContainer(container spos.ConsensusCoreHandler) v2.SubroundStartRound { + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + sr, _ := defaultSubround(consensusState, ch, container) + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + + return srStartRound +} + +func initSubroundStartRound() v2.SubroundStartRound { + container := consensus.InitConsensusCore() + return initSubroundStartRoundWithContainer(container) +} + +func TestNewSubroundStartRound(t *testing.T) { + t.Parallel() + + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + container := consensus.InitConsensusCore() + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + t.Run("nil subround should error", func(t *testing.T) { + t.Parallel() + + srStartRound, err := v2.NewSubroundStartRound( + nil, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilSubround, err) + }) + t.Run("nil sent signatures tracker should error", func(t *testing.T) { + t.Parallel() + + srStartRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + nil, + &consensus.SposWorkerMock{}, + ) + + assert.Nil(t, srStartRound) + assert.Equal(t, v2.ErrNilSentSignatureTracker, err) + }) + t.Run("nil worker should error", func(t *testing.T) { + t.Parallel() + + srStartRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + nil, + ) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilWorker, err) + }) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilBlockChainShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetBlockchain(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilBlockChain, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilBootstrapperShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetBootStrapper(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilBootstrapper, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilConsensusStateShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + + sr.ConsensusStateHandler = nil + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + 
assert.Equal(t, spos.ErrNilConsensusState, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilMultiSignerContainerShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetMultiSignerContainer(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilRoundHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetRoundHandler(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilRoundHandler, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilSyncTimerShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetSyncTimer(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilSyncTimer, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundNilValidatorGroupSelectorShouldFail(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + container.SetNodesCoordinator(nil) + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.Nil(t, srStartRound) + assert.Equal(t, spos.ErrNilNodesCoordinator, err) +} + +func TestSubroundStartRound_NewSubroundStartRoundShouldWork(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + + srStartRound, err := defaultSubroundStartRoundFromSubround(sr) + + assert.NotNil(t, srStartRound) + assert.Nil(t, err) +} + +func TestSubroundStartRound_DoStartRoundShouldReturnTrue(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + consensusState := initializers.InitConsensusState() + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + + srStartRound := defaultWithoutErrorSubroundStartRoundFromSubround(sr) + + r := srStartRound.DoStartRoundJob() + assert.True(t, r) +} + +func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenRoundIsCanceled(t *testing.T) { + t.Parallel() + + sr := initSubroundStartRound() + + sr.SetRoundCanceled(true) + + ok := sr.DoStartRoundConsensusCheck() + assert.False(t, ok) +} + +func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenRoundIsFinished(t *testing.T) { + t.Parallel() + + sr := initSubroundStartRound() + + sr.SetStatus(bls.SrStartRound, spos.SsFinished) + + ok := sr.DoStartRoundConsensusCheck() + assert.True(t, ok) +} + +func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnTrueWhenInitCurrentRoundReturnTrue(t *testing.T) { + t.Parallel() + + bootstrapperMock := 
&bootstrapperStubs.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { + return common.NsSynchronized + }} + + container := consensus.InitConsensusCore() + container.SetBootStrapper(bootstrapperMock) + + sr := initSubroundStartRoundWithContainer(container) + sentTrackerInterface := sr.GetSentSignatureTracker() + sentTracker := sentTrackerInterface.(*testscommon.SentSignatureTrackerStub) + startRoundCalled := false + sentTracker.StartRoundCalled = func() { + startRoundCalled = true + } + + ok := sr.DoStartRoundConsensusCheck() + assert.True(t, ok) + assert.True(t, startRoundCalled) +} + +func TestSubroundStartRound_DoStartRoundConsensusCheckShouldReturnFalseWhenInitCurrentRoundReturnFalse(t *testing.T) { + t.Parallel() + + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{GetNodeStateCalled: func() common.NodeState { + return common.NsNotSynchronized + }} + + container := consensus.InitConsensusCore() + container.SetBootStrapper(bootstrapperMock) + container.SetRoundHandler(initRoundHandlerMock()) + + sr := initSubroundStartRoundWithContainer(container) + + ok := sr.DoStartRoundConsensusCheck() + assert.False(t, ok) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGetNodeStateNotReturnSynchronized(t *testing.T) { + t.Parallel() + + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} + + bootstrapperMock.GetNodeStateCalled = func() common.NodeState { + return common.NsNotSynchronized + } + container := consensus.InitConsensusCore() + container.SetBootStrapper(bootstrapperMock) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.False(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGenerateNextConsensusGroupErr(t *testing.T) { + t.Parallel() + + validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} + + validatorGroupSelector.ComputeValidatorsGroupCalled = func(bytes []byte, round uint64, shardId uint32, epoch uint32) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + return nil, nil, expErr + } + container := consensus.InitConsensusCore() + + container.SetNodesCoordinator(validatorGroupSelector) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.False(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnTrueWhenMainMachineIsActive(t *testing.T) { + t.Parallel() + + nodeRedundancyMock := &mock.NodeRedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool { + return true + }, + } + container := consensus.InitConsensusCore() + container.SetNodeRedundancyHandler(nodeRedundancyMock) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.True(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenGetLeaderErr(t *testing.T) { + t.Parallel() + + validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} + leader := &shardingMocks.ValidatorMock{PubKeyCalled: func() []byte { + return []byte("leader") + }} + + validatorGroupSelector.ComputeValidatorsGroupCalled = func( + bytes []byte, + round uint64, + shardId uint32, + epoch uint32, + ) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + // will cause an error in GetLeader because of empty consensus group + return leader, []nodesCoordinator.Validator{}, nil + } + + container := consensus.InitConsensusCore() + container.SetNodesCoordinator(validatorGroupSelector) + + srStartRound := 
initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.False(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnTrueWhenIsNotInTheConsensusGroup(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + consensusState := initializers.InitConsensusState() + consensusState.SetSelfPubKey(consensusState.SelfPubKey() + "X") + ch := make(chan bool, 1) + + sr, _ := defaultSubround(consensusState, ch, container) + + srStartRound := defaultWithoutErrorSubroundStartRoundFromSubround(sr) + + r := srStartRound.InitCurrentRound() + assert.True(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenTimeIsOut(t *testing.T) { + t.Parallel() + + roundHandlerMock := initRoundHandlerMock() + + roundHandlerMock.RemainingTimeCalled = func(time.Time, time.Duration) time.Duration { + return time.Duration(-1) + } + + container := consensus.InitConsensusCore() + container.SetRoundHandler(roundHandlerMock) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.False(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnTrue(t *testing.T) { + t.Parallel() + + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} + + bootstrapperMock.GetNodeStateCalled = func() common.NodeState { + return common.NsSynchronized + } + + container := consensus.InitConsensusCore() + container.SetBootStrapper(bootstrapperMock) + + srStartRound := initSubroundStartRoundWithContainer(container) + + r := srStartRound.InitCurrentRound() + assert.True(t, r) +} + +func TestSubroundStartRound_InitCurrentRoundShouldMetrics(t *testing.T) { + t.Parallel() + + t.Run("not in consensus node", func(t *testing.T) { + t.Parallel() + + wasCalled := false + container := consensus.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} + appStatusHandler := &statusHandler.AppStatusHandlerStub{ + SetStringValueHandler: func(key string, value string) { + if key == common.MetricConsensusState { + wasCalled = true + assert.Equal(t, "not in consensus group", value) + } + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + consensusState.SetSelfPubKey("not in consensus") + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + srStartRound.Check() + assert.True(t, wasCalled) + }) + t.Run("main key participant", func(t *testing.T) { + t.Parallel() + + wasCalled := false + wasIncrementCalled := false + container := consensus.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return string(pkBytes) == "B" + }, + } + appStatusHandler := &statusHandler.AppStatusHandlerStub{ + SetStringValueHandler: func(key string, value string) { + if key == common.MetricConsensusState { + wasCalled = true + assert.Equal(t, "participant", value) + } + }, + IncrementHandler: func(key string) { + if key == common.MetricCountConsensus { + wasIncrementCalled = true + } + }, + } + ch := make(chan bool, 1) + consensusState := 
initializers.InitConsensusStateWithKeysHandler(keysHandler) + consensusState.SetSelfPubKey("B") + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + srStartRound.Check() + assert.True(t, wasCalled) + assert.True(t, wasIncrementCalled) + }) + t.Run("multi key participant", func(t *testing.T) { + t.Parallel() + + wasCalled := false + wasIncrementCalled := false + container := consensus.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} + appStatusHandler := &statusHandler.AppStatusHandlerStub{ + SetStringValueHandler: func(key string, value string) { + if key == common.MetricConsensusState { + wasCalled = true + assert.Equal(t, "participant", value) + } + }, + IncrementHandler: func(key string) { + if key == common.MetricCountConsensus { + wasIncrementCalled = true + } + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return string(pkBytes) == consensusState.SelfPubKey() + } + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + srStartRound.Check() + assert.True(t, wasCalled) + assert.True(t, wasIncrementCalled) + }) + t.Run("main key leader", func(t *testing.T) { + t.Parallel() + + wasMetricConsensusStateCalled := false + wasMetricCountLeaderCalled := false + cntMetricConsensusRoundStateCalled := 0 + container := consensus.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} + appStatusHandler := &statusHandler.AppStatusHandlerStub{ + SetStringValueHandler: func(key string, value string) { + if key == common.MetricConsensusState { + wasMetricConsensusStateCalled = true + assert.Equal(t, "proposer", value) + } + if key == common.MetricConsensusRoundState { + cntMetricConsensusRoundStateCalled++ + switch cntMetricConsensusRoundStateCalled { + case 1: + assert.Equal(t, "", value) + case 2: + assert.Equal(t, "proposed", value) + default: + assert.Fail(t, "should have been called only twice") + } + } + }, + IncrementHandler: func(key string) { + if key == common.MetricCountLeader { + wasMetricCountLeaderCalled = true + } + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + leader, _ := consensusState.GetLeader() + consensusState.SetSelfPubKey(leader) + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + srStartRound.Check() + 
assert.True(t, wasMetricConsensusStateCalled) + assert.True(t, wasMetricCountLeaderCalled) + assert.Equal(t, 2, cntMetricConsensusRoundStateCalled) + }) + t.Run("managed key leader", func(t *testing.T) { + t.Parallel() + + wasMetricConsensusStateCalled := false + wasMetricCountLeaderCalled := false + cntMetricConsensusRoundStateCalled := 0 + container := consensus.InitConsensusCore() + keysHandler := &testscommon.KeysHandlerStub{} + appStatusHandler := &statusHandler.AppStatusHandlerStub{ + SetStringValueHandler: func(key string, value string) { + if key == common.MetricConsensusState { + wasMetricConsensusStateCalled = true + assert.Equal(t, "proposer", value) + } + if key == common.MetricConsensusRoundState { + cntMetricConsensusRoundStateCalled++ + switch cntMetricConsensusRoundStateCalled { + case 1: + assert.Equal(t, "", value) + case 2: + assert.Equal(t, "proposed", value) + default: + assert.Fail(t, "should have been called only twice") + } + } + }, + IncrementHandler: func(key string) { + if key == common.MetricCountLeader { + wasMetricCountLeaderCalled = true + } + }, + } + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusStateWithKeysHandler(keysHandler) + leader, _ := consensusState.GetLeader() + consensusState.SetSelfPubKey(leader) + keysHandler.IsKeyManagedByCurrentNodeCalled = func(pkBytes []byte) bool { + return string(pkBytes) == leader + } + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + appStatusHandler, + ) + + srStartRound, _ := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + srStartRound.Check() + assert.True(t, wasMetricConsensusStateCalled) + assert.True(t, wasMetricCountLeaderCalled) + assert.Equal(t, 2, cntMetricConsensusRoundStateCalled) + }) +} + +func buildDefaultSubround(container spos.ConsensusCoreHandler) *spos.Subround { + ch := make(chan bool, 1) + consensusState := initializers.InitConsensusState() + sr, _ := spos.NewSubround( + -1, + bls.SrStartRound, + bls.SrBlock, + int64(85*roundTimeDuration/100), + int64(95*roundTimeDuration/100), + "(START_ROUND)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + return sr +} + +func TestSubroundStartRound_GenerateNextConsensusGroupShouldErrNilHeader(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + chainHandlerMock := &testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return nil + }, + } + + container.SetBlockchain(chainHandlerMock) + + sr := buildDefaultSubround(container) + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + err = startRound.GenerateNextConsensusGroup(0) + + assert.Equal(t, spos.ErrNilHeader, err) +} + +func TestSubroundStartRound_InitCurrentRoundShouldReturnFalseWhenResetErr(t *testing.T) { + t.Parallel() + + container := consensus.InitConsensusCore() + + signingHandlerMock := &consensus.SigningHandlerStub{ + ResetCalled: func(pubKeys []string) error { + return expErr + }, + } + + container.SetSigningHandler(signingHandlerMock) + + sr := buildDefaultSubround(container) + startRound, err := 
v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + r := startRound.InitCurrentRound() + + assert.False(t, r) +} + +func TestSubroundStartRound_IndexRoundIfNeededFailShardIdForEpoch(t *testing.T) { + + pubKeys := []string{"testKey1", "testKey2"} + + container := consensus.InitConsensusCore() + + idVar := 0 + + container.SetShardCoordinator(&processMock.CoordinatorStub{ + SelfIdCalled: func() uint32 { + return uint32(idVar) + }, + }) + + container.SetNodesCoordinator( + &shardingMocks.NodesCoordinatorStub{ + ShardIdForEpochCalled: func(epoch uint32) (uint32, error) { + return 0, expErr + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + _ = startRound.SetOutportHandler(&outport.OutportStub{ + HasDriversCalled: func() bool { + return true + }, + SaveRoundsInfoCalled: func(roundsInfo *outportcore.RoundsInfo) { + require.Fail(t, "SaveRoundsInfo should not be called") + }, + }) + + startRound.IndexRoundIfNeeded(pubKeys) + +} + +func TestSubroundStartRound_IndexRoundIfNeededGetValidatorsIndexesShouldNotBeCalled(t *testing.T) { + + pubKeys := []string{"testKey1", "testKey2"} + + container := consensus.InitConsensusCore() + + idVar := 0 + + container.SetShardCoordinator(&processMock.CoordinatorStub{ + SelfIdCalled: func() uint32 { + return uint32(idVar) + }, + }) + + container.SetNodesCoordinator( + &shardingMocks.NodesCoordinatorStub{ + GetValidatorsIndexesCalled: func(pubKeys []string, epoch uint32) ([]uint64, error) { + require.Fail(t, "SaveRoundsInfo should not be called") + return nil, expErr + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + called := false + _ = startRound.SetOutportHandler(&outport.OutportStub{ + HasDriversCalled: func() bool { + return true + }, + SaveRoundsInfoCalled: func(roundsInfo *outportcore.RoundsInfo) { + called = true + }, + }) + + startRound.IndexRoundIfNeeded(pubKeys) + require.True(t, called) +} + +func TestSubroundStartRound_IndexRoundIfNeededShouldFullyWork(t *testing.T) { + + pubKeys := []string{"testKey1", "testKey2"} + + container := consensus.InitConsensusCore() + + idVar := 0 + + saveRoundInfoCalled := false + + container.SetShardCoordinator(&processMock.CoordinatorStub{ + SelfIdCalled: func() uint32 { + return uint32(idVar) + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + _ = startRound.SetOutportHandler(&outport.OutportStub{ + HasDriversCalled: func() bool { + return true + }, + SaveRoundsInfoCalled: func(roundsInfo *outportcore.RoundsInfo) { + saveRoundInfoCalled = true + }}) + + startRound.IndexRoundIfNeeded(pubKeys) + + assert.True(t, saveRoundInfoCalled) + +} + +func TestSubroundStartRound_IndexRoundIfNeededDifferentShardIdFail(t *testing.T) { + + pubKeys := []string{"testKey1", "testKey2"} + + container := consensus.InitConsensusCore() + + shardID := 1 + container.SetShardCoordinator(&processMock.CoordinatorStub{ + SelfIdCalled: func() uint32 { + return 
uint32(shardID) + }, + }) + + container.SetNodesCoordinator(&shardingMocks.NodesCoordinatorStub{ + ShardIdForEpochCalled: func(epoch uint32) (uint32, error) { + return 0, nil + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + + _ = startRound.SetOutportHandler(&outport.OutportStub{ + HasDriversCalled: func() bool { + return true + }, + SaveRoundsInfoCalled: func(roundsInfo *outportcore.RoundsInfo) { + require.Fail(t, "SaveRoundsInfo should not be called") + }, + }) + + startRound.IndexRoundIfNeeded(pubKeys) + +} + +func TestSubroundStartRound_changeEpoch(t *testing.T) { + t.Parallel() + + expectPanic := func() { + if recover() == nil { + require.Fail(t, "expected panic") + } + } + + expectNoPanic := func() { + if recover() != nil { + require.Fail(t, "expected no panic") + } + } + + t.Run("error returned by nodes coordinator should error", func(t *testing.T) { + t.Parallel() + + defer expectPanic() + + container := consensus.InitConsensusCore() + exErr := fmt.Errorf("expected error") + container.SetNodesCoordinator( + &shardingMocks.NodesCoordinatorStub{ + GetConsensusWhitelistedNodesCalled: func(epoch uint32) (map[string]struct{}, error) { + return nil, exErr + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + startRound.ChangeEpoch(1) + }) + t.Run("success - no panic", func(t *testing.T) { + t.Parallel() + + defer expectNoPanic() + + container := consensus.InitConsensusCore() + expectedKeys := map[string]struct{}{ + "aaa": {}, + "bbb": {}, + } + + container.SetNodesCoordinator( + &shardingMocks.NodesCoordinatorStub{ + GetConsensusWhitelistedNodesCalled: func(epoch uint32) (map[string]struct{}, error) { + return expectedKeys, nil + }, + }) + + sr := buildDefaultSubround(container) + + startRound, err := v2.NewSubroundStartRound( + sr, + v2.ProcessingThresholdPercent, + &testscommon.SentSignatureTrackerStub{}, + &consensus.SposWorkerMock{}, + ) + require.Nil(t, err) + startRound.ChangeEpoch(1) + }) +} + +func TestSubroundStartRound_GenerateNextConsensusGroupShouldReturnErr(t *testing.T) { + t.Parallel() + + validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} + + validatorGroupSelector.ComputeValidatorsGroupCalled = func( + bytes []byte, + round uint64, + shardId uint32, + epoch uint32, + ) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + return nil, nil, expErr + } + container := consensus.InitConsensusCore() + container.SetNodesCoordinator(validatorGroupSelector) + + srStartRound := initSubroundStartRoundWithContainer(container) + + err2 := srStartRound.GenerateNextConsensusGroup(0) + + assert.Equal(t, expErr, err2) +} diff --git a/consensus/spos/consensusCore.go b/consensus/spos/consensusCore.go index 2cf7ca369d6..c255d704822 100644 --- a/consensus/spos/consensusCore.go +++ b/consensus/spos/consensusCore.go @@ -4,6 +4,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" 
"github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/epochStart" @@ -38,6 +40,10 @@ type ConsensusCore struct { messageSigningHandler consensus.P2PSigningHandler peerBlacklistHandler consensus.PeerBlacklistHandler signingHandler consensus.SigningHandler + enableEpochsHandler common.EnableEpochsHandler + equivalentProofsPool consensus.EquivalentProofsPool + epochNotifier process.EpochNotifier + invalidSignersCache InvalidSignersCache } // ConsensusCoreArgs store all arguments that are needed to create a ConsensusCore object @@ -64,6 +70,10 @@ type ConsensusCoreArgs struct { MessageSigningHandler consensus.P2PSigningHandler PeerBlacklistHandler consensus.PeerBlacklistHandler SigningHandler consensus.SigningHandler + EnableEpochsHandler common.EnableEpochsHandler + EquivalentProofsPool consensus.EquivalentProofsPool + EpochNotifier process.EpochNotifier + InvalidSignersCache InvalidSignersCache } // NewConsensusCore creates a new ConsensusCore instance @@ -93,6 +103,10 @@ func NewConsensusCore( messageSigningHandler: args.MessageSigningHandler, peerBlacklistHandler: args.PeerBlacklistHandler, signingHandler: args.SigningHandler, + enableEpochsHandler: args.EnableEpochsHandler, + equivalentProofsPool: args.EquivalentProofsPool, + epochNotifier: args.EpochNotifier, + invalidSignersCache: args.InvalidSignersCache, } err := ValidateConsensusCore(consensusCore) @@ -173,6 +187,11 @@ func (cc *ConsensusCore) EpochStartRegistrationHandler() epochStart.Registration return cc.epochStartRegistrationHandler } +// EpochNotifier returns the epoch notifier +func (cc *ConsensusCore) EpochNotifier() process.EpochNotifier { + return cc.epochNotifier +} + // PeerHonestyHandler will return the peer honesty handler which will be used in subrounds func (cc *ConsensusCore) PeerHonestyHandler() consensus.PeerHonestyHandler { return cc.peerHonestyHandler @@ -213,6 +232,151 @@ func (cc *ConsensusCore) SigningHandler() consensus.SigningHandler { return cc.signingHandler } +// EnableEpochsHandler returns the enable epochs handler component +func (cc *ConsensusCore) EnableEpochsHandler() common.EnableEpochsHandler { + return cc.enableEpochsHandler +} + +// EquivalentProofsPool returns the equivalent proofs component +func (cc *ConsensusCore) EquivalentProofsPool() consensus.EquivalentProofsPool { + return cc.equivalentProofsPool +} + +// InvalidSignersCache returns the invalid signers cache component +func (cc *ConsensusCore) InvalidSignersCache() InvalidSignersCache { + return cc.invalidSignersCache +} + +// SetBlockchain sets blockchain handler +func (cc *ConsensusCore) SetBlockchain(blockChain data.ChainHandler) { + cc.blockChain = blockChain +} + +// SetBlockProcessor sets block processor +func (cc *ConsensusCore) SetBlockProcessor(blockProcessor process.BlockProcessor) { + cc.blockProcessor = blockProcessor +} + +// SetBootStrapper sets process bootstrapper +func (cc *ConsensusCore) SetBootStrapper(bootstrapper process.Bootstrapper) { + cc.bootstrapper = bootstrapper +} + +// SetBroadcastMessenger sets broadcast messenger +func (cc *ConsensusCore) SetBroadcastMessenger(broadcastMessenger consensus.BroadcastMessenger) { + cc.broadcastMessenger = broadcastMessenger +} + +// SetChronology sets chronology +func (cc *ConsensusCore) SetChronology(chronologyHandler consensus.ChronologyHandler) { + cc.chronologyHandler = chronologyHandler +} + +// SetHasher sets hasher component +func (cc *ConsensusCore) SetHasher(hasher hashing.Hasher) { + cc.hasher = hasher +} 
+ +// SetMarshalizer sets marshaller component +func (cc *ConsensusCore) SetMarshalizer(marshalizer marshal.Marshalizer) { + cc.marshalizer = marshalizer +} + +// SetMultiSignerContainer sets multi signer container +func (cc *ConsensusCore) SetMultiSignerContainer(multiSignerContainer cryptoCommon.MultiSignerContainer) { + cc.multiSignerContainer = multiSignerContainer +} + +// SetRoundHandler sets round handler +func (cc *ConsensusCore) SetRoundHandler(roundHandler consensus.RoundHandler) { + cc.roundHandler = roundHandler +} + +// SetShardCoordinator sets shard coordinator +func (cc *ConsensusCore) SetShardCoordinator(shardCoordinator sharding.Coordinator) { + cc.shardCoordinator = shardCoordinator +} + +// SetSyncTimer sets sync timer +func (cc *ConsensusCore) SetSyncTimer(syncTimer ntp.SyncTimer) { + cc.syncTimer = syncTimer +} + +// SetNodesCoordinator sets nodes coordinator +func (cc *ConsensusCore) SetNodesCoordinator(nodesCoordinator nodesCoordinator.NodesCoordinator) { + cc.nodesCoordinator = nodesCoordinator +} + +// SetEpochStartNotifier sets epoch start notifier +func (cc *ConsensusCore) SetEpochStartNotifier(epochStartNotifier epochStart.RegistrationHandler) { + cc.epochStartRegistrationHandler = epochStartNotifier +} + +// SetAntifloodHandler sets antiflood handler +func (cc *ConsensusCore) SetAntifloodHandler(antifloodHandler consensus.P2PAntifloodHandler) { + cc.antifloodHandler = antifloodHandler +} + +// SetPeerHonestyHandler sets peer honesty handler +func (cc *ConsensusCore) SetPeerHonestyHandler(peerHonestyHandler consensus.PeerHonestyHandler) { + cc.peerHonestyHandler = peerHonestyHandler +} + +// SetScheduledProcessor sets scheduled processor +func (cc *ConsensusCore) SetScheduledProcessor(scheduledProcessor consensus.ScheduledProcessor) { + cc.scheduledProcessor = scheduledProcessor +} + +// SetPeerBlacklistHandler sets peer blacklist handler +func (cc *ConsensusCore) SetPeerBlacklistHandler(peerBlacklistHandler consensus.PeerBlacklistHandler) { + cc.peerBlacklistHandler = peerBlacklistHandler +} + +// SetHeaderSigVerifier sets header sig verifier +func (cc *ConsensusCore) SetHeaderSigVerifier(headerSigVerifier consensus.HeaderSigVerifier) { + cc.headerSigVerifier = headerSigVerifier +} + +// SetFallbackHeaderValidator sets fallback header validator +func (cc *ConsensusCore) SetFallbackHeaderValidator(fallbackHeaderValidator consensus.FallbackHeaderValidator) { + cc.fallbackHeaderValidator = fallbackHeaderValidator +} + +// SetNodeRedundancyHandler sets node redundancy handler +func (cc *ConsensusCore) SetNodeRedundancyHandler(nodeRedundancyHandler consensus.NodeRedundancyHandler) { + cc.nodeRedundancyHandler = nodeRedundancyHandler +} + +// SetMessageSigningHandler sets message signing handler +func (cc *ConsensusCore) SetMessageSigningHandler(messageSigningHandler consensus.P2PSigningHandler) { + cc.messageSigningHandler = messageSigningHandler +} + +// SetSigningHandler sets signing handler +func (cc *ConsensusCore) SetSigningHandler(signingHandler consensus.SigningHandler) { + cc.signingHandler = signingHandler +} + +// SetEnableEpochsHandler sets enable epochs handler +func (cc *ConsensusCore) SetEnableEpochsHandler(enableEpochsHandler common.EnableEpochsHandler) { + cc.enableEpochsHandler = enableEpochsHandler +} + +// SetEquivalentProofsPool sets equivalent proofs pool +func (cc *ConsensusCore) SetEquivalentProofsPool(proofPool consensus.EquivalentProofsPool) { + cc.equivalentProofsPool = proofPool +} + +// SetEpochNotifier sets epoch notifier +func (cc 
*ConsensusCore) SetEpochNotifier(epochNotifier process.EpochNotifier) { + cc.epochNotifier = epochNotifier +} + +// SetInvalidSignersCache sets the invalid signers cache +func (cc *ConsensusCore) SetInvalidSignersCache(cache InvalidSignersCache) { + cc.invalidSignersCache = cache +} + // IsInterfaceNil returns true if there is no value under the interface func (cc *ConsensusCore) IsInterfaceNil() bool { return cc == nil diff --git a/consensus/spos/consensusCoreValidator.go b/consensus/spos/consensusCoreValidator.go index 239c762f6d3..e3033fa24a9 100644 --- a/consensus/spos/consensusCoreValidator.go +++ b/consensus/spos/consensusCoreValidator.go @@ -1,6 +1,8 @@ package spos -import "github.com/multiversx/mx-chain-core-go/core/check" +import ( + "github.com/multiversx/mx-chain-core-go/core/check" +) // ValidateConsensusCore checks for nil all the container objects func ValidateConsensusCore(container ConsensusCoreHandler) error { @@ -74,6 +76,21 @@ func ValidateConsensusCore(container ConsensusCoreHandler) error { if check.IfNil(container.SigningHandler()) { return ErrNilSigningHandler } + if check.IfNil(container.EnableEpochsHandler()) { + return ErrNilEnableEpochsHandler + } + if check.IfNil(container.EquivalentProofsPool()) { + return ErrNilEquivalentProofPool + } + if check.IfNil(container.EpochNotifier()) { + return ErrNilEpochNotifier + } + if check.IfNil(container.EpochStartRegistrationHandler()) { + return ErrNilEpochStartNotifier + } + if check.IfNil(container.InvalidSignersCache()) { + return ErrNilInvalidSignersCache + } return nil } diff --git a/consensus/spos/consensusCoreValidator_test.go b/consensus/spos/consensusCoreValidator_test.go index acdc008cbe8..f199cd0b7e5 100644 --- a/consensus/spos/consensusCoreValidator_test.go +++ b/consensus/spos/consensusCoreValidator_test.go @@ -1,33 +1,41 @@ -package spos +package spos_test import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" + epochstartmock "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" ) -func initConsensusDataContainer() *ConsensusCore { +func initConsensusDataContainer() *spos.ConsensusCore { marshalizerMock := mock.MarshalizerMock{} blockChain := &testscommon.ChainHandlerStub{} - blockProcessorMock := mock.InitBlockProcessorMock(marshalizerMock) - bootstrapperMock := &mock.BootstrapperStub{} - broadcastMessengerMock := &mock.BroadcastMessengerMock{} - chronologyHandlerMock := mock.InitChronologyHandlerMock() + blockProcessorMock := consensusMocks.InitBlockProcessorMock(marshalizerMock) + bootstrapperMock := 
&bootstrapperStubs.BootstrapperStub{} + broadcastMessengerMock := &consensusMocks.BroadcastMessengerMock{} + chronologyHandlerMock := consensusMocks.InitChronologyHandlerMock() multiSignerMock := cryptoMocks.NewMultiSigner() hasherMock := &hashingMocks.HasherMock{} - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensusMocks.RoundHandlerMock{} + epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} shardCoordinatorMock := mock.ShardCoordinatorMock{} - syncTimerMock := &mock.SyncTimerMock{} - validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{} + syncTimerMock := &consensusMocks.SyncTimerMock{} + nodesCoordinator := &shardingMocks.NodesCoordinatorMock{} antifloodHandler := &mock.P2PAntifloodHandlerStub{} peerHonestyHandler := &testscommon.PeerHonestyHandlerStub{} - headerSigVerifier := &mock.HeaderSigVerifierStub{} + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{} fallbackHeaderValidator := &testscommon.FallBackHeaderValidatorStub{} nodeRedundancyHandler := &mock.NodeRedundancyHandlerStub{} scheduledProcessor := &consensusMocks.ScheduledProcessorStub{} @@ -35,235 +43,353 @@ func initConsensusDataContainer() *ConsensusCore { peerBlacklistHandler := &mock.PeerBlacklistHandlerStub{} multiSignerContainer := cryptoMocks.NewMultiSignerContainerMock(multiSignerMock) signingHandler := &consensusMocks.SigningHandlerStub{} + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} + proofsPool := &dataRetriever.ProofsPoolMock{} + epochNotifier := &epochNotifierMock.EpochNotifierStub{} + invalidSignersCache := &consensusMocks.InvalidSignersCacheMock{} + + consensusCore, _ := spos.NewConsensusCore(&spos.ConsensusCoreArgs{ + BlockChain: blockChain, + BlockProcessor: blockProcessorMock, + Bootstrapper: bootstrapperMock, + BroadcastMessenger: broadcastMessengerMock, + ChronologyHandler: chronologyHandlerMock, + Hasher: hasherMock, + Marshalizer: marshalizerMock, + MultiSignerContainer: multiSignerContainer, + RoundHandler: roundHandlerMock, + ShardCoordinator: shardCoordinatorMock, + SyncTimer: syncTimerMock, + NodesCoordinator: nodesCoordinator, + EpochStartRegistrationHandler: epochStartSubscriber, + AntifloodHandler: antifloodHandler, + PeerHonestyHandler: peerHonestyHandler, + HeaderSigVerifier: headerSigVerifier, + FallbackHeaderValidator: fallbackHeaderValidator, + NodeRedundancyHandler: nodeRedundancyHandler, + ScheduledProcessor: scheduledProcessor, + MessageSigningHandler: messageSigningHandler, + PeerBlacklistHandler: peerBlacklistHandler, + SigningHandler: signingHandler, + EnableEpochsHandler: enableEpochsHandler, + EquivalentProofsPool: proofsPool, + EpochNotifier: epochNotifier, + InvalidSignersCache: invalidSignersCache, + }) + + return consensusCore +} - return &ConsensusCore{ - blockChain: blockChain, - blockProcessor: blockProcessorMock, - bootstrapper: bootstrapperMock, - broadcastMessenger: broadcastMessengerMock, - chronologyHandler: chronologyHandlerMock, - hasher: hasherMock, - marshalizer: marshalizerMock, - multiSignerContainer: multiSignerContainer, - roundHandler: roundHandlerMock, - shardCoordinator: shardCoordinatorMock, - syncTimer: syncTimerMock, - nodesCoordinator: validatorGroupSelector, - antifloodHandler: antifloodHandler, - peerHonestyHandler: peerHonestyHandler, - headerSigVerifier: headerSigVerifier, - fallbackHeaderValidator: fallbackHeaderValidator, - nodeRedundancyHandler: nodeRedundancyHandler, - scheduledProcessor: scheduledProcessor, - messageSigningHandler: messageSigningHandler, - 
peerBlacklistHandler: peerBlacklistHandler, - signingHandler: signingHandler, - } +func TestConsensusContainerValidator_ValidateNilConsensusCoreFail(t *testing.T) { + t.Parallel() + + err := spos.ValidateConsensusCore(nil) + + assert.Equal(t, spos.ErrNilConsensusCore, err) } func TestConsensusContainerValidator_ValidateNilBlockchainShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.blockChain = nil + container.SetBlockchain(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilBlockChain, err) + assert.Equal(t, spos.ErrNilBlockChain, err) } func TestConsensusContainerValidator_ValidateNilProcessorShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.blockProcessor = nil + container.SetBlockProcessor(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilBlockProcessor, err) + assert.Equal(t, spos.ErrNilBlockProcessor, err) } func TestConsensusContainerValidator_ValidateNilBootstrapperShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.bootstrapper = nil + container.SetBootStrapper(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilBootstrapper, err) + assert.Equal(t, spos.ErrNilBootstrapper, err) } func TestConsensusContainerValidator_ValidateNilChronologyShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.chronologyHandler = nil + container.SetChronology(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilChronologyHandler, err) + assert.Equal(t, spos.ErrNilChronologyHandler, err) } func TestConsensusContainerValidator_ValidateNilHasherShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.hasher = nil + container.SetHasher(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilHasher, err) + assert.Equal(t, spos.ErrNilHasher, err) } func TestConsensusContainerValidator_ValidateNilMarshalizerShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.marshalizer = nil + container.SetMarshalizer(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilMarshalizer, err) + assert.Equal(t, spos.ErrNilMarshalizer, err) } func TestConsensusContainerValidator_ValidateNilMultiSignerContainerShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.multiSignerContainer = nil + container.SetMultiSignerContainer(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilMultiSignerContainer, err) + assert.Equal(t, spos.ErrNilMultiSignerContainer, err) } func TestConsensusContainerValidator_ValidateNilMultiSignerShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.multiSignerContainer = cryptoMocks.NewMultiSignerContainerMock(nil) + container.SetMultiSignerContainer(cryptoMocks.NewMultiSignerContainerMock(nil)) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilMultiSigner, err) + assert.Equal(t, spos.ErrNilMultiSigner, err) } func 
TestConsensusContainerValidator_ValidateNilRoundHandlerShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.roundHandler = nil + container.SetRoundHandler(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilRoundHandler, err) + assert.Equal(t, spos.ErrNilRoundHandler, err) } func TestConsensusContainerValidator_ValidateNilShardCoordinatorShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.shardCoordinator = nil + container.SetShardCoordinator(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilShardCoordinator, err) + assert.Equal(t, spos.ErrNilShardCoordinator, err) } func TestConsensusContainerValidator_ValidateNilSyncTimerShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.syncTimer = nil + container.SetSyncTimer(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilSyncTimer, err) + assert.Equal(t, spos.ErrNilSyncTimer, err) } func TestConsensusContainerValidator_ValidateNilValidatorGroupSelectorShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.nodesCoordinator = nil + container.SetNodesCoordinator(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilNodesCoordinator, err) + assert.Equal(t, spos.ErrNilNodesCoordinator, err) } func TestConsensusContainerValidator_ValidateNilAntifloodHandlerShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.antifloodHandler = nil + container.SetAntifloodHandler(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilAntifloodHandler, err) + assert.Equal(t, spos.ErrNilAntifloodHandler, err) } func TestConsensusContainerValidator_ValidateNilPeerHonestyHandlerShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.peerHonestyHandler = nil + container.SetPeerHonestyHandler(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilPeerHonestyHandler, err) + assert.Equal(t, spos.ErrNilPeerHonestyHandler, err) } func TestConsensusContainerValidator_ValidateNilHeaderSigVerifierShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.headerSigVerifier = nil + container.SetHeaderSigVerifier(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilHeaderSigVerifier, err) + assert.Equal(t, spos.ErrNilHeaderSigVerifier, err) } func TestConsensusContainerValidator_ValidateNilFallbackHeaderValidatorShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.fallbackHeaderValidator = nil + container.SetFallbackHeaderValidator(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilFallbackHeaderValidator, err) + assert.Equal(t, spos.ErrNilFallbackHeaderValidator, err) } func TestConsensusContainerValidator_ValidateNilNodeRedundancyHandlerShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.nodeRedundancyHandler = nil + container.SetNodeRedundancyHandler(nil) - err := ValidateConsensusCore(container) + err := 
spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilNodeRedundancyHandler, err) + assert.Equal(t, spos.ErrNilNodeRedundancyHandler, err) } func TestConsensusContainerValidator_ValidateNilSignatureHandlerShouldFail(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - container.signingHandler = nil + container.SetSigningHandler(nil) + + err := spos.ValidateConsensusCore(container) + + assert.Equal(t, spos.ErrNilSigningHandler, err) +} + +func TestConsensusContainerValidator_ValidateNilEnableEpochsHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := initConsensusDataContainer() + container.SetEnableEpochsHandler(nil) + + err := spos.ValidateConsensusCore(container) + + assert.Equal(t, spos.ErrNilEnableEpochsHandler, err) +} + +func TestConsensusContainerValidator_ValidateNilBroadcastMessengerShouldFail(t *testing.T) { + t.Parallel() + + container := initConsensusDataContainer() + container.SetBroadcastMessenger(nil) + + err := spos.ValidateConsensusCore(container) + + assert.Equal(t, spos.ErrNilBroadcastMessenger, err) +} + +func TestConsensusContainerValidator_ValidateNilScheduledProcessorShouldFail(t *testing.T) { + t.Parallel() + + container := initConsensusDataContainer() + container.SetScheduledProcessor(nil) + + err := spos.ValidateConsensusCore(container) + + assert.Equal(t, spos.ErrNilScheduledProcessor, err) +} + +func TestConsensusContainerValidator_ValidateNilMessageSigningHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := initConsensusDataContainer() + container.SetMessageSigningHandler(nil) + + err := spos.ValidateConsensusCore(container) + + assert.Equal(t, spos.ErrNilMessageSigningHandler, err) +} + +func TestConsensusContainerValidator_ValidateNilPeerBlacklistHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := initConsensusDataContainer() + container.SetPeerBlacklistHandler(nil) + + err := spos.ValidateConsensusCore(container) + + assert.Equal(t, spos.ErrNilPeerBlacklistHandler, err) +} + +func TestConsensusContainerValidator_ValidateNilEquivalentProofPoolShouldFail(t *testing.T) { + t.Parallel() + + container := initConsensusDataContainer() + container.SetEquivalentProofsPool(nil) + + err := spos.ValidateConsensusCore(container) + + assert.Equal(t, spos.ErrNilEquivalentProofPool, err) +} + +func TestConsensusContainerValidator_ValidateNilEpochNotifierShouldFail(t *testing.T) { + t.Parallel() + + container := initConsensusDataContainer() + container.SetEpochNotifier(nil) + + err := spos.ValidateConsensusCore(container) + + assert.Equal(t, spos.ErrNilEpochNotifier, err) +} + +func TestConsensusContainerValidator_ValidateNilEpochStartRegistrationHandlerShouldFail(t *testing.T) { + t.Parallel() + + container := initConsensusDataContainer() + container.SetEpochStartNotifier(nil) + + err := spos.ValidateConsensusCore(container) + + assert.Equal(t, spos.ErrNilEpochStartNotifier, err) +} + +func TestConsensusContainerValidator_ValidateNilInvalidSignersCacheShouldFail(t *testing.T) { + t.Parallel() + + container := initConsensusDataContainer() + container.SetInvalidSignersCache(nil) - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) - assert.Equal(t, ErrNilSigningHandler, err) + assert.Equal(t, spos.ErrNilInvalidSignersCache, err) } func TestConsensusContainerValidator_ShouldWork(t *testing.T) { t.Parallel() container := initConsensusDataContainer() - err := ValidateConsensusCore(container) + err := spos.ValidateConsensusCore(container) assert.Nil(t, err) } diff --git 
a/consensus/spos/consensusCore_test.go b/consensus/spos/consensusCore_test.go index 2fd67a2cb63..1f44827f857 100644 --- a/consensus/spos/consensusCore_test.go +++ b/consensus/spos/consensusCore_test.go @@ -3,15 +3,15 @@ package spos_test import ( "testing" - "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" - "github.com/stretchr/testify/assert" ) func createDefaultConsensusCoreArgs() *spos.ConsensusCoreArgs { - consensusCoreMock := mock.InitConsensusCore() + consensusCoreMock := consensus.InitConsensusCore() scheduledProcessor := &consensus.ScheduledProcessorStub{} @@ -38,6 +38,10 @@ func createDefaultConsensusCoreArgs() *spos.ConsensusCoreArgs { MessageSigningHandler: consensusCoreMock.MessageSigningHandler(), PeerBlacklistHandler: consensusCoreMock.PeerBlacklistHandler(), SigningHandler: consensusCoreMock.SigningHandler(), + EnableEpochsHandler: consensusCoreMock.EnableEpochsHandler(), + EquivalentProofsPool: consensusCoreMock.EquivalentProofsPool(), + EpochNotifier: consensusCoreMock.EpochNotifier(), + InvalidSignersCache: &consensus.InvalidSignersCacheMock{}, } return args } @@ -334,6 +338,34 @@ func TestConsensusCore_WithNilPeerBlacklistHandlerShouldFail(t *testing.T) { assert.Equal(t, spos.ErrNilPeerBlacklistHandler, err) } +func TestConsensusCore_WithNilEnableEpochsHandlerShouldFail(t *testing.T) { + t.Parallel() + + args := createDefaultConsensusCoreArgs() + args.EnableEpochsHandler = nil + + consensusCore, err := spos.NewConsensusCore( + args, + ) + + assert.Nil(t, consensusCore) + assert.Equal(t, spos.ErrNilEnableEpochsHandler, err) +} + +func TestConsensusCore_WithNilEpochStartRegistrationHandlerShouldFail(t *testing.T) { + t.Parallel() + + args := createDefaultConsensusCoreArgs() + args.EpochStartRegistrationHandler = nil + + consensusCore, err := spos.NewConsensusCore( + args, + ) + + assert.Nil(t, consensusCore) + assert.Equal(t, spos.ErrNilEpochStartNotifier, err) +} + func TestConsensusCore_CreateConsensusCoreShouldWork(t *testing.T) { t.Parallel() diff --git a/consensus/spos/consensusMessageValidator.go b/consensus/spos/consensusMessageValidator.go index 67fa9616e07..c2a63264c75 100644 --- a/consensus/spos/consensusMessageValidator.go +++ b/consensus/spos/consensusMessageValidator.go @@ -7,9 +7,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/sharding" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -17,6 +21,9 @@ type consensusMessageValidator struct { consensusState *ConsensusState consensusService ConsensusService peerSignatureHandler crypto.PeerSignatureHandler + enableEpochsHandler common.EnableEpochsHandler + marshaller marshal.Marshalizer + shardCoordinator sharding.Coordinator signatureSize int publicKeySize int 
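Editorial note on the validator changes that follow: the hunks below wire three new dependencies into consensusMessageValidator (EnableEpochsHandler, Marshaller, ShardCoordinator) so the validator can unmarshal the header carried by a consensus message and query the Andromeda activation flag for that header's epoch. Once the flag is active, the expectation on SignatureShare is inverted for block-header messages: a missing share becomes the invalid case. A minimal, stand-alone sketch of that gating pattern follows; the flagChecker interface, the string flag name and stubHandler are simplified stand-ins for common.EnableEpochsHandler, common.AndromedaFlag and the real test stubs, not code from this diff.

```go
// Editorial sketch, not part of the diff: the shape of the epoch-gated
// SignatureShare check introduced in consensusMessageValidator.
package main

import "fmt"

// flagChecker is a simplified stand-in for common.EnableEpochsHandler.
type flagChecker interface {
	IsFlagEnabledInEpoch(flag string, epoch uint32) bool
}

// isInvalidSigShare mirrors the inversion done in
// checkMessageWithBlockBodyAndHeaderValidity: before activation a header
// message must NOT carry a signature share, after activation it MUST.
func isInvalidSigShare(sigShare []byte, headerEpoch uint32, handler flagChecker) bool {
	invalid := sigShare != nil
	if handler.IsFlagEnabledInEpoch("andromeda", headerEpoch) {
		invalid = sigShare == nil
	}
	return invalid
}

// stubHandler enables the flag from a given epoch onwards (illustrative only).
type stubHandler struct{ activationEpoch uint32 }

func (s stubHandler) IsFlagEnabledInEpoch(_ string, epoch uint32) bool {
	return epoch >= s.activationEpoch
}

func main() {
	h := stubHandler{activationEpoch: 2}
	fmt.Println(isInvalidSigShare([]byte("share"), 1, h)) // true: share not expected before activation
	fmt.Println(isInvalidSigShare(nil, 3, h))             // true: share required after activation
	fmt.Println(isInvalidSigShare([]byte("share"), 3, h)) // false: valid after activation
}
```

The same inversion is applied in checkMessageWithBlockHeaderValidity, and checkMessageWithFinalInfoValidity skips the leader-signature size check through shouldNotVerifyLeaderSignature once the flag is active for the header currently held in the consensus state.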
@@ -33,6 +40,9 @@ type ArgsConsensusMessageValidator struct { ConsensusState *ConsensusState ConsensusService ConsensusService PeerSignatureHandler crypto.PeerSignatureHandler + EnableEpochsHandler common.EnableEpochsHandler + Marshaller marshal.Marshalizer + ShardCoordinator sharding.Coordinator SignatureSize int PublicKeySize int HeaderHashSize int @@ -50,6 +60,9 @@ func NewConsensusMessageValidator(args ArgsConsensusMessageValidator) (*consensu consensusState: args.ConsensusState, consensusService: args.ConsensusService, peerSignatureHandler: args.PeerSignatureHandler, + enableEpochsHandler: args.EnableEpochsHandler, + marshaller: args.Marshaller, + shardCoordinator: args.ShardCoordinator, signatureSize: args.SignatureSize, publicKeySize: args.PublicKeySize, chainID: args.ChainID, @@ -69,6 +82,15 @@ func checkArgsConsensusMessageValidator(args ArgsConsensusMessageValidator) erro if check.IfNil(args.PeerSignatureHandler) { return ErrNilPeerSignatureHandler } + if check.IfNil(args.EnableEpochsHandler) { + return ErrNilEnableEpochsHandler + } + if check.IfNil(args.Marshaller) { + return ErrNilMarshalizer + } + if check.IfNil(args.ShardCoordinator) { + return ErrNilShardCoordinator + } if args.ConsensusState == nil { return ErrNilConsensusState } @@ -137,13 +159,13 @@ func (cmv *consensusMessageValidator) checkConsensusMessageValidity(cnsMsg *cons msgType := consensus.MessageType(cnsMsg.MsgType) - if cmv.consensusState.RoundIndex+1 < cnsMsg.RoundIndex { + if cmv.consensusState.GetRoundIndex()+1 < cnsMsg.RoundIndex { log.Trace("received message from consensus topic has a future round", "msg type", cmv.consensusService.GetStringValue(msgType), "from", cnsMsg.PubKey, "header hash", cnsMsg.BlockHeaderHash, "msg round", cnsMsg.RoundIndex, - "round", cmv.consensusState.RoundIndex, + "round", cmv.consensusState.GetRoundIndex(), ) return fmt.Errorf("%w : received message from consensus topic has a future round: %d", @@ -151,13 +173,13 @@ func (cmv *consensusMessageValidator) checkConsensusMessageValidity(cnsMsg *cons cnsMsg.RoundIndex) } - if cmv.consensusState.RoundIndex > cnsMsg.RoundIndex { + if cmv.consensusState.GetRoundIndex() > cnsMsg.RoundIndex { log.Trace("received message from consensus topic has a past round", "msg type", cmv.consensusService.GetStringValue(msgType), "from", cnsMsg.PubKey, "header hash", cnsMsg.BlockHeaderHash, "msg round", cnsMsg.RoundIndex, - "round", cmv.consensusState.RoundIndex, + "round", cmv.consensusState.GetRoundIndex(), ) return fmt.Errorf("%w : received message from consensus topic has a past round: %d", @@ -239,7 +261,19 @@ func (cmv *consensusMessageValidator) checkConsensusMessageValidityForMessageTyp } func (cmv *consensusMessageValidator) checkMessageWithBlockBodyAndHeaderValidity(cnsMsg *consensus.Message) error { - isMessageInvalid := cnsMsg.SignatureShare != nil || + // TODO[cleanup cns finality]: remove this + isInvalidSigShare := cnsMsg.SignatureShare != nil + + header, err := process.UnmarshalHeader(cmv.shardCoordinator.SelfId(), cmv.marshaller, cnsMsg.Header) + if err != nil { + return err + } + + if cmv.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + isInvalidSigShare = cnsMsg.SignatureShare == nil + } + + isMessageInvalid := isInvalidSigShare || cnsMsg.PubKeysBitmap != nil || cnsMsg.AggregateSignature != nil || cnsMsg.LeaderSignature != nil || @@ -306,8 +340,19 @@ func (cmv *consensusMessageValidator) checkMessageWithBlockBodyValidity(cnsMsg * } func (cmv *consensusMessageValidator) 
checkMessageWithBlockHeaderValidity(cnsMsg *consensus.Message) error { + // TODO[cleanup cns finality]: remove this + isInvalidSigShare := cnsMsg.SignatureShare != nil + + header, err := process.UnmarshalHeader(cmv.shardCoordinator.SelfId(), cmv.marshaller, cnsMsg.Header) + if err != nil { + return err + } + + if cmv.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + isInvalidSigShare = cnsMsg.SignatureShare == nil + } isMessageInvalid := cnsMsg.Body != nil || - cnsMsg.SignatureShare != nil || + isInvalidSigShare || cnsMsg.PubKeysBitmap != nil || cnsMsg.AggregateSignature != nil || cnsMsg.LeaderSignature != nil || @@ -398,6 +443,11 @@ func (cmv *consensusMessageValidator) checkMessageWithFinalInfoValidity(cnsMsg * len(cnsMsg.AggregateSignature)) } + // TODO[cleanup cns finality]: remove this + if cmv.shouldNotVerifyLeaderSignature() { + return nil + } + if len(cnsMsg.LeaderSignature) != cmv.signatureSize { return fmt.Errorf("%w : received leader signature from consensus topic has an invalid size: %d", ErrInvalidSignatureSize, @@ -407,6 +457,16 @@ func (cmv *consensusMessageValidator) checkMessageWithFinalInfoValidity(cnsMsg * return nil } +func (cmv *consensusMessageValidator) shouldNotVerifyLeaderSignature() bool { + // TODO: this check needs to be removed when equivalent messages are sent separately from the final info + if check.IfNil(cmv.consensusState.GetHeader()) { + return true + } + + return cmv.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, cmv.consensusState.GetHeader().GetEpoch()) + +} + func (cmv *consensusMessageValidator) checkMessageWithInvalidSingersValidity(cnsMsg *consensus.Message) error { isMessageInvalid := cnsMsg.SignatureShare != nil || cnsMsg.Body != nil || diff --git a/consensus/spos/consensusMessageValidator_test.go b/consensus/spos/consensusMessageValidator_test.go index 33c37ea4e70..9936694d21f 100644 --- a/consensus/spos/consensusMessageValidator_test.go +++ b/consensus/spos/consensusMessageValidator_test.go @@ -6,13 +6,20 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" + "github.com/multiversx/mx-chain-go/testscommon" + testscommonConsensus "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" ) func createDefaultConsensusMessageValidatorArgs() spos.ArgsConsensusMessageValidator { @@ -26,7 +33,7 @@ func createDefaultConsensusMessageValidatorArgs() spos.ArgsConsensusMessageValid return nil }, } - keyGeneratorMock, _, _ := mock.InitKeys() + keyGeneratorMock, _, _ := testscommonConsensus.InitKeys() peerSigHandler := &mock.PeerSignatureHandler{Signer: singleSignerMock, KeyGen: keyGeneratorMock} hasher := 
&hashingMocks.HasherMock{} @@ -34,6 +41,9 @@ func createDefaultConsensusMessageValidatorArgs() spos.ArgsConsensusMessageValid ConsensusState: consensusState, ConsensusService: blsService, PeerSignatureHandler: peerSigHandler, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + Marshaller: &marshallerMock.MarshalizerStub{}, + ShardCoordinator: &testscommon.ShardsCoordinatorMock{}, SignatureSize: SignatureSize, PublicKeySize: PublicKeySize, HeaderHashSize: hasher.Size(), @@ -64,6 +74,36 @@ func TestNewConsensusMessageValidator(t *testing.T) { assert.Nil(t, validator) assert.Equal(t, spos.ErrNilPeerSignatureHandler, err) }) + t.Run("nil EnableEpochsHandler", func(t *testing.T) { + t.Parallel() + + args := createDefaultConsensusMessageValidatorArgs() + args.EnableEpochsHandler = nil + validator, err := spos.NewConsensusMessageValidator(args) + + assert.Nil(t, validator) + assert.Equal(t, spos.ErrNilEnableEpochsHandler, err) + }) + t.Run("nil Marshaller", func(t *testing.T) { + t.Parallel() + + args := createDefaultConsensusMessageValidatorArgs() + args.Marshaller = nil + validator, err := spos.NewConsensusMessageValidator(args) + + assert.Nil(t, validator) + assert.Equal(t, spos.ErrNilMarshalizer, err) + }) + t.Run("nil ShardCoordinator", func(t *testing.T) { + t.Parallel() + + args := createDefaultConsensusMessageValidatorArgs() + args.ShardCoordinator = nil + validator, err := spos.NewConsensusMessageValidator(args) + + assert.Nil(t, validator) + assert.Equal(t, spos.ErrNilShardCoordinator, err) + }) t.Run("nil ConsensusState", func(t *testing.T) { t.Parallel() @@ -179,17 +219,55 @@ func TestCheckMessageWithFinalInfoValidity_InvalidAggregateSignatureSize(t *test assert.True(t, errors.Is(err, spos.ErrInvalidSignatureSize)) } -func TestCheckMessageWithFinalInfoValidity_InvalidLeaderSignatureSize(t *testing.T) { +func TestCheckMessageWithFinalInfo_LeaderSignatureCheck(t *testing.T) { t.Parallel() - consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() - cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) + t.Run("should fail", func(t *testing.T) { + t.Parallel() - sig := make([]byte, SignatureSize) - _, _ = rand.Read(sig) - cnsMsg := &consensus.Message{PubKeysBitmap: []byte("01"), AggregateSignature: sig, LeaderSignature: []byte("0")} - err := cmv.CheckMessageWithFinalInfoValidity(cnsMsg) - assert.True(t, errors.Is(err, spos.ErrInvalidSignatureSize)) + consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() + consensusMessageValidatorArgs.ConsensusState.SetHeader(&block.Header{Epoch: 2}) + + sigSize := SignatureSize + consensusMessageValidatorArgs.SignatureSize = sigSize // different signature size + + cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) + + cnsMsg := &consensus.Message{ + MsgType: int64(bls.MtBlockHeaderFinalInfo), + AggregateSignature: make([]byte, SignatureSize), + LeaderSignature: make([]byte, SignatureSize-1), + PubKeysBitmap: []byte("11"), + } + err := cmv.CheckConsensusMessageValidityForMessageType(cnsMsg) + assert.NotNil(t, err) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + + consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() + consensusMessageValidatorArgs.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + 
consensusMessageValidatorArgs.ConsensusState.SetHeader(&block.Header{Epoch: 2}) + + sigSize := SignatureSize + consensusMessageValidatorArgs.SignatureSize = sigSize // different signature size + + cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) + + cnsMsg := &consensus.Message{ + MsgType: int64(bls.MtBlockHeaderFinalInfo), + AggregateSignature: make([]byte, SignatureSize), + LeaderSignature: make([]byte, SignatureSize-1), + PubKeysBitmap: []byte("11"), + } + err := cmv.CheckConsensusMessageValidityForMessageType(cnsMsg) + assert.Nil(t, err) + }) } func TestCheckMessageWithFinalInfoValidity_ShouldWork(t *testing.T) { @@ -337,6 +415,22 @@ func TestCheckMessageWithBlockBodyValidity_ShouldWork(t *testing.T) { assert.Nil(t, err) } +func TestCheckMessageWithBlockBodyAndHeaderValidity_NilSigShareAfterActivation(t *testing.T) { + t.Parallel() + + consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() + consensusMessageValidatorArgs.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) + + cnsMsg := &consensus.Message{SignatureShare: nil} + err := cmv.CheckMessageWithBlockBodyAndHeaderValidity(cnsMsg) + assert.True(t, errors.Is(err, spos.ErrInvalidMessage)) +} + func TestCheckMessageWithBlockBodyAndHeaderValidity_InvalidMessage(t *testing.T) { t.Parallel() @@ -420,6 +514,22 @@ func TestCheckConsensusMessageValidityForMessageType_MessageWithBlockHeaderInval assert.True(t, errors.Is(err, spos.ErrInvalidMessage)) } +func TestCheckConsensusMessageValidityForMessageType_MessageWithBlockHeaderInvalidAfterFlag(t *testing.T) { + t.Parallel() + + consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() + consensusMessageValidatorArgs.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) + + cnsMsg := &consensus.Message{MsgType: int64(bls.MtBlockHeader), SignatureShare: nil} + err := cmv.CheckConsensusMessageValidityForMessageType(cnsMsg) + assert.True(t, errors.Is(err, spos.ErrInvalidMessage)) +} + func TestCheckConsensusMessageValidityForMessageType_MessageWithSignatureInvalid(t *testing.T) { t.Parallel() @@ -655,7 +765,7 @@ func TestCheckConsensusMessageValidity_ErrMessageForPastRound(t *testing.T) { t.Parallel() consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() - consensusMessageValidatorArgs.ConsensusState.RoundIndex = 100 + consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(100) cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) headerBytes := make([]byte, 100) @@ -678,7 +788,7 @@ func TestCheckConsensusMessageValidity_ErrMessageTypeLimitReached(t *testing.T) t.Parallel() consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() - consensusMessageValidatorArgs.ConsensusState.RoundIndex = 10 + consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(10) cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) pubKey := []byte(consensusMessageValidatorArgs.ConsensusState.ConsensusGroup()[0]) @@ -724,7 +834,7 @@ func createMockConsensusMessage(args spos.ArgsConsensusMessageValidator, pubKey 
MsgType: int64(msgType), PubKey: pubKey, Signature: createDummyByteSlice(SignatureSize), - RoundIndex: args.ConsensusState.RoundIndex, + RoundIndex: args.ConsensusState.GetRoundIndex(), BlockHeaderHash: createDummyByteSlice(args.HeaderHashSize), } } @@ -743,7 +853,7 @@ func TestCheckConsensusMessageValidity_InvalidSignature(t *testing.T) { consensusMessageValidatorArgs.PeerSignatureHandler = &mock.PeerSignatureHandler{ Signer: signer, } - consensusMessageValidatorArgs.ConsensusState.RoundIndex = 10 + consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(10) cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) headerBytes := make([]byte, 100) @@ -766,7 +876,7 @@ func TestCheckConsensusMessageValidity_Ok(t *testing.T) { t.Parallel() consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs() - consensusMessageValidatorArgs.ConsensusState.RoundIndex = 10 + consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(10) cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs) headerBytes := make([]byte, 100) diff --git a/consensus/spos/consensusState.go b/consensus/spos/consensusState.go index 564b3def852..476b90133e6 100644 --- a/consensus/spos/consensusState.go +++ b/consensus/spos/consensusState.go @@ -7,24 +7,26 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - logger "github.com/multiversx/mx-chain-logger-go" ) -// IndexOfLeaderInConsensusGroup represents the index of the leader in the consensus group -const IndexOfLeaderInConsensusGroup = 0 - var log = logger.GetOrCreate("consensus/spos") // ConsensusState defines the data needed by spos to do the consensus in each round type ConsensusState struct { // hold the data on which validators do the consensus (could be for example a hash of the block header // proposed by the leader) - Data []byte - Body data.BodyHandler - Header data.HeaderHandler + Data []byte + + body data.BodyHandler + mutBody sync.RWMutex + + header data.HeaderHandler + mutHeader sync.RWMutex receivedHeaders []data.HeaderHandler mutReceivedHeaders sync.RWMutex @@ -44,6 +46,8 @@ type ConsensusState struct { *roundConsensus *roundThreshold *roundStatus + + mutState sync.RWMutex } // NewConsensusState creates a new ConsensusState object @@ -64,21 +68,25 @@ func NewConsensusState( return &cns } +// ResetConsensusRoundState method resets all the consensus round data (except messages received) +func (cns *ConsensusState) ResetConsensusRoundState() { + cns.RoundCanceled = false + cns.ExtendedCalled = false + cns.WaitingAllSignaturesTimeOut = false + cns.ResetRoundStatus() + cns.ResetRoundState() +} + // ResetConsensusState method resets all the consensus data func (cns *ConsensusState) ResetConsensusState() { - cns.Body = nil - cns.Header = nil + cns.SetBody(nil) + cns.SetHeader(nil) cns.Data = nil cns.initReceivedHeaders() cns.initReceivedMessagesWithSig() - cns.RoundCanceled = false - cns.ExtendedCalled = false - cns.WaitingAllSignaturesTimeOut = false - - cns.ResetRoundStatus() - cns.ResetRoundState() + cns.ResetConsensusRoundState() } func (cns *ConsensusState) initReceivedHeaders() { @@ -136,11 +144,6 @@ func (cns *ConsensusState) 
IsNodeLeaderInCurrentRound(node string) bool { return leader == node } -// IsSelfLeaderInCurrentRound method checks if the current node is leader in the current round -func (cns *ConsensusState) IsSelfLeaderInCurrentRound() bool { - return cns.IsNodeLeaderInCurrentRound(cns.selfPubKey) -} - // GetLeader method gets the leader of the current round func (cns *ConsensusState) GetLeader() (string, error) { if cns.consensusGroup == nil { @@ -151,7 +154,7 @@ func (cns *ConsensusState) GetLeader() (string, error) { return "", ErrEmptyConsensusGroup } - return cns.consensusGroup[IndexOfLeaderInConsensusGroup], nil + return cns.Leader(), nil } // GetNextConsensusGroup gets the new consensus group for the current round based on current eligible list and a random @@ -162,8 +165,8 @@ func (cns *ConsensusState) GetNextConsensusGroup( shardId uint32, nodesCoordinator nodesCoordinator.NodesCoordinator, epoch uint32, -) ([]string, error) { - validatorsGroup, err := nodesCoordinator.ComputeConsensusGroup(randomSource, round, shardId, epoch) +) (string, []string, error) { + leader, validatorsGroup, err := nodesCoordinator.ComputeConsensusGroup(randomSource, round, shardId, epoch) if err != nil { log.Debug( "compute consensus group", @@ -173,7 +176,7 @@ func (cns *ConsensusState) GetNextConsensusGroup( "shardId", shardId, "epoch", epoch, ) - return nil, err + return "", nil, err } consensusSize := len(validatorsGroup) @@ -183,7 +186,7 @@ func (cns *ConsensusState) GetNextConsensusGroup( newConsensusGroup[i] = string(validatorsGroup[i].PubKey()) } - return newConsensusGroup, nil + return string(leader.PubKey()), newConsensusGroup, nil } // IsConsensusDataSet method returns true if the consensus data for the current round is set and false otherwise @@ -212,11 +215,6 @@ func (cns *ConsensusState) IsJobDone(node string, currentSubroundId int) bool { return jobDone } -// IsSelfJobDone method returns true if self job for the current subround is done and false otherwise -func (cns *ConsensusState) IsSelfJobDone(currentSubroundId int) bool { - return cns.IsJobDone(cns.selfPubKey, currentSubroundId) -} - // IsSubroundFinished method returns true if the current subround is finished and false otherwise func (cns *ConsensusState) IsSubroundFinished(subroundID int) bool { isSubroundFinished := cns.Status(subroundID) == SsFinished @@ -233,14 +231,14 @@ func (cns *ConsensusState) IsNodeSelf(node string) bool { // IsBlockBodyAlreadyReceived method returns true if block body is already received and false otherwise func (cns *ConsensusState) IsBlockBodyAlreadyReceived() bool { - isBlockBodyAlreadyReceived := cns.Body != nil + isBlockBodyAlreadyReceived := cns.GetBody() != nil return isBlockBodyAlreadyReceived } // IsHeaderAlreadyReceived method returns true if header is already received and false otherwise func (cns *ConsensusState) IsHeaderAlreadyReceived() bool { - isHeaderAlreadyReceived := cns.Header != nil + isHeaderAlreadyReceived := cns.GetHeader() != nil return isHeaderAlreadyReceived } @@ -251,16 +249,7 @@ func (cns *ConsensusState) CanDoSubroundJob(currentSubroundId int) bool { return false } - selfJobDone := true - if cns.IsNodeInConsensusGroup(cns.SelfPubKey()) { - selfJobDone = cns.IsSelfJobDone(currentSubroundId) - } - multiKeyJobDone := true - if cns.IsMultiKeyInConsensusGroup() { - multiKeyJobDone = cns.IsMultiKeyJobDone(currentSubroundId) - } - - if selfJobDone && multiKeyJobDone { + if cns.IsSelfJobDone(currentSubroundId) { return false } @@ -341,6 +330,11 @@ func (cns *ConsensusState) GetData() []byte { 
return cns.Data } +// SetData sets the Data of the consensusState +func (cns *ConsensusState) SetData(data []byte) { + cns.Data = data +} + // IsMultiKeyLeaderInCurrentRound method checks if one of the nodes which are controlled by this instance // is leader in the current round func (cns *ConsensusState) IsMultiKeyLeaderInCurrentRound() bool { @@ -350,7 +344,7 @@ func (cns *ConsensusState) IsMultiKeyLeaderInCurrentRound() bool { return false } - return cns.IsKeyManagedByCurrentNode([]byte(leader)) + return cns.IsKeyManagedBySelf([]byte(leader)) } // IsLeaderJobDone method returns true if the leader job for the current subround is done and false otherwise @@ -380,6 +374,21 @@ func (cns *ConsensusState) IsMultiKeyJobDone(currentSubroundId int) bool { return true } +// IsSelfJobDone method returns true if self job for the current subround is done and false otherwise +func (cns *ConsensusState) IsSelfJobDone(currentSubroundID int) bool { + selfJobDone := true + if cns.IsNodeInConsensusGroup(cns.SelfPubKey()) { + selfJobDone = cns.IsJobDone(cns.SelfPubKey(), currentSubroundID) + } + + multiKeyJobDone := true + if cns.IsMultiKeyInConsensusGroup() { + multiKeyJobDone = cns.IsMultiKeyJobDone(currentSubroundID) + } + + return selfJobDone && multiKeyJobDone +} + // GetMultikeyRedundancyStepInReason returns the reason if the current node stepped in as a multikey redundancy node func (cns *ConsensusState) GetMultikeyRedundancyStepInReason() string { return cns.keysHandler.GetRedundancyStepInReason() @@ -390,3 +399,108 @@ func (cns *ConsensusState) GetMultikeyRedundancyStepInReason() string { func (cns *ConsensusState) ResetRoundsWithoutReceivedMessages(pkBytes []byte, pid core.PeerID) { cns.keysHandler.ResetRoundsWithoutReceivedMessages(pkBytes, pid) } + +// GetRoundCanceled returns the state of the current round +func (cns *ConsensusState) GetRoundCanceled() bool { + cns.mutState.RLock() + defer cns.mutState.RUnlock() + + return cns.RoundCanceled +} + +// SetRoundCanceled sets the state of the current round +func (cns *ConsensusState) SetRoundCanceled(roundCanceled bool) { + cns.mutState.Lock() + defer cns.mutState.Unlock() + + cns.RoundCanceled = roundCanceled +} + +// GetRoundIndex returns the index of the current round +func (cns *ConsensusState) GetRoundIndex() int64 { + cns.mutState.RLock() + defer cns.mutState.RUnlock() + + return cns.RoundIndex +} + +// SetRoundIndex sets the index of the current round +func (cns *ConsensusState) SetRoundIndex(roundIndex int64) { + cns.mutState.Lock() + defer cns.mutState.Unlock() + + cns.RoundIndex = roundIndex +} + +// GetRoundTimeStamp returns the time stamp of the current round +func (cns *ConsensusState) GetRoundTimeStamp() time.Time { + return cns.RoundTimeStamp +} + +// SetRoundTimeStamp sets the time stamp of the current round +func (cns *ConsensusState) SetRoundTimeStamp(roundTimeStamp time.Time) { + cns.RoundTimeStamp = roundTimeStamp +} + +// GetExtendedCalled returns the state of the extended called +func (cns *ConsensusState) GetExtendedCalled() bool { + return cns.ExtendedCalled +} + +// SetExtendedCalled sets the state of the extended called +func (cns *ConsensusState) SetExtendedCalled(extendedCalled bool) { + cns.ExtendedCalled = extendedCalled +} + +// GetBody returns the body of the current round +func (cns *ConsensusState) GetBody() data.BodyHandler { + cns.mutBody.RLock() + defer cns.mutBody.RUnlock() + + return cns.body +} + +// SetBody sets the body of the current round +func (cns *ConsensusState) SetBody(body data.BodyHandler) { + 
cns.mutBody.Lock() + defer cns.mutBody.Unlock() + + cns.body = body +} + +// GetHeader returns the header of the current round +func (cns *ConsensusState) GetHeader() data.HeaderHandler { + cns.mutHeader.RLock() + defer cns.mutHeader.RUnlock() + + return cns.header +} + +// SetHeader sets the header of the current round +func (cns *ConsensusState) SetHeader(header data.HeaderHandler) { + cns.mutHeader.Lock() + defer cns.mutHeader.Unlock() + + cns.header = header +} + +// GetWaitingAllSignaturesTimeOut returns the state of the waiting all signatures time out +func (cns *ConsensusState) GetWaitingAllSignaturesTimeOut() bool { + cns.mutState.RLock() + defer cns.mutState.RUnlock() + + return cns.WaitingAllSignaturesTimeOut +} + +// SetWaitingAllSignaturesTimeOut sets the state of the waiting all signatures time out +func (cns *ConsensusState) SetWaitingAllSignaturesTimeOut(waitingAllSignaturesTimeOut bool) { + cns.mutState.Lock() + defer cns.mutState.Unlock() + + cns.WaitingAllSignaturesTimeOut = waitingAllSignaturesTimeOut +} + +// IsInterfaceNil returns true if there is no value under the interface +func (cns *ConsensusState) IsInterfaceNil() bool { + return cns == nil +} diff --git a/consensus/spos/consensusState_test.go b/consensus/spos/consensusState_test.go index 554c9c0c755..62855a79ea9 100644 --- a/consensus/spos/consensusState_test.go +++ b/consensus/spos/consensusState_test.go @@ -4,16 +4,20 @@ import ( "bytes" "errors" "testing" + "time" + p2pMessage "github.com/multiversx/mx-chain-communication-go/p2p/message" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" ) func internalInitConsensusState() *spos.ConsensusState { @@ -36,6 +40,7 @@ func internalInitConsensusStateWithKeysHandler(keysHandler consensus.KeysHandler ) rcns.SetConsensusGroup(eligibleList) + rcns.SetLeader(eligibleList[0]) rcns.ResetRoundState() rthr := spos.NewRoundThreshold() @@ -68,12 +73,13 @@ func TestConsensusState_ResetConsensusStateShouldWork(t *testing.T) { t.Parallel() cns := internalInitConsensusState() - cns.RoundCanceled = true - cns.ExtendedCalled = true - cns.WaitingAllSignaturesTimeOut = true + cns.SetRoundCanceled(true) + require.True(t, cns.GetRoundCanceled()) + cns.SetExtendedCalled(true) + cns.SetWaitingAllSignaturesTimeOut(true) cns.ResetConsensusState() assert.False(t, cns.RoundCanceled) - assert.False(t, cns.ExtendedCalled) + assert.False(t, cns.GetExtendedCalled()) assert.False(t, cns.WaitingAllSignaturesTimeOut) } @@ -102,22 +108,6 @@ func TestConsensusState_IsNodeLeaderInCurrentRoundShouldReturnTrue(t *testing.T) assert.Equal(t, true, cns.IsNodeLeaderInCurrentRound("1")) } -func TestConsensusState_IsSelfLeaderInCurrentRoundShouldReturnFalse(t *testing.T) { - t.Parallel() - - cns := internalInitConsensusState() - - assert.False(t, cns.IsSelfLeaderInCurrentRound()) -} - -func 
TestConsensusState_IsSelfLeaderInCurrentRoundShouldReturnTrue(t *testing.T) { - t.Parallel() - - cns := internalInitConsensusState() - - assert.False(t, cns.IsSelfLeaderInCurrentRound()) -} - func TestConsensusState_GetLeaderShoudErrNilConsensusGroup(t *testing.T) { t.Parallel() @@ -162,11 +152,11 @@ func TestConsensusState_GetNextConsensusGroupShouldFailWhenComputeValidatorsGrou round uint64, shardId uint32, epoch uint32, - ) ([]nodesCoordinator.Validator, error) { - return nil, err + ) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { + return nil, nil, err } - _, err2 := cns.GetNextConsensusGroup([]byte(""), 0, 0, nodesCoord, 0) + _, _, err2 := cns.GetNextConsensusGroup([]byte(""), 0, 0, nodesCoord, 0) assert.Equal(t, err, err2) } @@ -176,10 +166,11 @@ func TestConsensusState_GetNextConsensusGroupShouldWork(t *testing.T) { cns := internalInitConsensusState() nodesCoord := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]nodesCoordinator.Validator, error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { defaultSelectionChances := uint32(1) - return []nodesCoordinator.Validator{ - shardingMocks.NewValidatorMock([]byte("A"), 1, defaultSelectionChances), + leader := shardingMocks.NewValidatorMock([]byte("A"), 1, defaultSelectionChances) + return leader, []nodesCoordinator.Validator{ + leader, shardingMocks.NewValidatorMock([]byte("B"), 1, defaultSelectionChances), shardingMocks.NewValidatorMock([]byte("C"), 1, defaultSelectionChances), shardingMocks.NewValidatorMock([]byte("D"), 1, defaultSelectionChances), @@ -192,9 +183,10 @@ func TestConsensusState_GetNextConsensusGroupShouldWork(t *testing.T) { }, } - nextConsensusGroup, err := cns.GetNextConsensusGroup(nil, 0, 0, nodesCoord, 0) + leader, nextConsensusGroup, err := cns.GetNextConsensusGroup(nil, 0, 0, nodesCoord, 0) assert.Nil(t, err) assert.NotNil(t, nextConsensusGroup) + assert.NotEmpty(t, leader) } func TestConsensusState_IsConsensusDataSetShouldReturnTrue(t *testing.T) { @@ -334,7 +326,7 @@ func TestConsensusState_IsBlockBodyAlreadyReceivedShouldReturnFalse(t *testing.T cns := internalInitConsensusState() - cns.Body = nil + cns.SetBody(nil) assert.False(t, cns.IsBlockBodyAlreadyReceived()) } @@ -344,7 +336,7 @@ func TestConsensusState_IsBlockBodyAlreadyReceivedShouldReturnTrue(t *testing.T) cns := internalInitConsensusState() - cns.Body = &block.Body{} + cns.SetBody(&block.Body{}) assert.True(t, cns.IsBlockBodyAlreadyReceived()) } @@ -354,7 +346,7 @@ func TestConsensusState_IsHeaderAlreadyReceivedShouldReturnFalse(t *testing.T) { cns := internalInitConsensusState() - cns.Header = nil + cns.SetHeader(nil) assert.False(t, cns.IsHeaderAlreadyReceived()) } @@ -364,7 +356,7 @@ func TestConsensusState_IsHeaderAlreadyReceivedShouldReturnTrue(t *testing.T) { cns := internalInitConsensusState() - cns.Header = &block.Header{} + cns.SetHeader(&block.Header{}) assert.True(t, cns.IsHeaderAlreadyReceived()) } @@ -617,3 +609,59 @@ func TestConsensusState_ResetRoundsWithoutReceivedMessages(t *testing.T) { cns.ResetRoundsWithoutReceivedMessages(testPkBytes, testPid) assert.True(t, resetRoundsWithoutReceivedMessagesCalled) } + +func TestConsensusState_GettersSetters(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{} + cns := internalInitConsensusStateWithKeysHandler(keysHandler) + + 
providedIndex := int64(123) + cns.SetRoundIndex(providedIndex) + require.Equal(t, providedIndex, cns.GetRoundIndex()) + + providedTimestamp := time.Now() + cns.SetRoundTimeStamp(providedTimestamp) + require.Equal(t, providedTimestamp, cns.GetRoundTimeStamp()) + + cns.SetExtendedCalled(true) + require.True(t, cns.GetExtendedCalled()) + + providedBody := &block.Body{} + cns.SetBody(providedBody) + require.Equal(t, providedBody, cns.GetBody()) + + providedHeader := &block.Header{} + cns.SetHeader(providedHeader) + require.Equal(t, providedHeader, cns.GetHeader()) + + cns.SetWaitingAllSignaturesTimeOut(true) + require.True(t, cns.GetWaitingAllSignaturesTimeOut()) + + providedData := []byte("hash") + cns.SetData(providedData) + require.Equal(t, string(providedData), string(cns.GetData())) + + cns.AddReceivedHeader(providedHeader) + receivedHeaders := cns.GetReceivedHeaders() + require.Equal(t, 1, len(receivedHeaders)) + require.Equal(t, providedHeader, receivedHeaders[0]) + + providedMsg := &p2pMessage.Message{} + providedKey := "key" + cns.AddMessageWithSignature(providedKey, providedMsg) + msg, ok := cns.GetMessageWithSignature(providedKey) + require.True(t, ok) + require.Equal(t, providedMsg, msg) +} + +func TestConsensusState_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var cns *spos.ConsensusState + require.True(t, cns.IsInterfaceNil()) + + keysHandler := &testscommon.KeysHandlerStub{} + cns = internalInitConsensusStateWithKeysHandler(keysHandler) + require.False(t, cns.IsInterfaceNil()) +} diff --git a/consensus/spos/errors.go b/consensus/spos/errors.go index 3aeac029da3..d89a58865f3 100644 --- a/consensus/spos/errors.go +++ b/consensus/spos/errors.go @@ -243,3 +243,42 @@ var ErrNilFunctionHandler = errors.New("nil function handler") // ErrWrongHashForHeader signals that the hash of the header is not the expected one var ErrWrongHashForHeader = errors.New("wrong hash for header") + +// ErrNilSentSignatureTracker defines the error for setting a nil SentSignatureTracker +var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") + +// ErrEquivalentMessageAlreadyReceived signals that an equivalent message has been already received +var ErrEquivalentMessageAlreadyReceived = errors.New("equivalent message already received") + +// ErrNilEnableEpochsHandler signals that a nil enable epochs handler has been provided +var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") + +// ErrNilThrottler signals that a nil throttler has been provided +var ErrNilThrottler = errors.New("nil throttler") + +// ErrTimeIsOut signals that time is out +var ErrTimeIsOut = errors.New("time is out") + +// ErrNilEquivalentProofPool signals that a nil proof pool has been provided +var ErrNilEquivalentProofPool = errors.New("nil equivalent proof pool") + +// ErrNilHeaderProof signals that a nil header proof has been provided +var ErrNilHeaderProof = errors.New("nil header proof") + +// ErrHeaderProofNotExpected signals that a header proof was not expected +var ErrHeaderProofNotExpected = errors.New("header proof not expected") + +// ErrConsensusMessageNotExpected signals that a consensus message was not expected +var ErrConsensusMessageNotExpected = errors.New("consensus message not expected") + +// ErrNilEpochNotifier signals that a nil epoch notifier has been provided +var ErrNilEpochNotifier = errors.New("nil epoch notifier") + +// ErrNilEpochStartNotifier signals that nil epoch start notifier has been provided +var ErrNilEpochStartNotifier = errors.New("nil epoch start notifier") + 
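Editorial note on the ConsensusState changes above: direct field access (cns.Body, cns.Header, cns.RoundIndex, ...) is replaced with getters and setters guarded by dedicated RWMutexes, since the state is now reached through the ConsensusStateHandler interface from concurrently running subround code. A stand-alone sketch of that accessor pattern follows, using only the standard library; roundState is an illustrative type and not the ConsensusState from this diff.

```go
// Editorial sketch, not part of the diff: mutex-guarded accessors replacing
// direct field access on shared consensus round state.
package main

import (
	"fmt"
	"sync"
)

type roundState struct {
	mutState   sync.RWMutex
	roundIndex int64

	mutHeader sync.RWMutex
	header    string // stands in for data.HeaderHandler
}

// SetRoundIndex writes the round index under the state mutex.
func (rs *roundState) SetRoundIndex(idx int64) {
	rs.mutState.Lock()
	defer rs.mutState.Unlock()
	rs.roundIndex = idx
}

// GetRoundIndex reads the round index under a read lock.
func (rs *roundState) GetRoundIndex() int64 {
	rs.mutState.RLock()
	defer rs.mutState.RUnlock()
	return rs.roundIndex
}

// SetHeader and GetHeader use a separate mutex, as the diff does with mutHeader.
func (rs *roundState) SetHeader(h string) {
	rs.mutHeader.Lock()
	defer rs.mutHeader.Unlock()
	rs.header = h
}

func (rs *roundState) GetHeader() string {
	rs.mutHeader.RLock()
	defer rs.mutHeader.RUnlock()
	return rs.header
}

func main() {
	rs := &roundState{}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) { // concurrent writers are safe behind the mutexes
			defer wg.Done()
			rs.SetRoundIndex(int64(i))
			rs.SetHeader(fmt.Sprintf("header-%d", i))
		}(i)
	}
	wg.Wait()
	fmt.Println(rs.GetRoundIndex(), rs.GetHeader())
}
```

Note that the diff guards the body, the header and the round flags behind mutBody, mutHeader and mutState, while a few accessors such as GetRoundTimeStamp and GetExtendedCalled remain unguarded; the sketch mirrors only the guarded fields.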
+// ErrInvalidSignersAlreadyReceived signals that an invalid signers message has been already received +var ErrInvalidSignersAlreadyReceived = errors.New("invalid signers already received") + +// ErrNilInvalidSignersCache signals that nil invalid signers has been provided +var ErrNilInvalidSignersCache = errors.New("nil invalid signers cache") diff --git a/consensus/spos/export_test.go b/consensus/spos/export_test.go index 39d19de6e30..6ada6ceccde 100644 --- a/consensus/spos/export_test.go +++ b/consensus/spos/export_test.go @@ -3,9 +3,13 @@ package spos import ( "context" "fmt" + "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/process" ) @@ -13,6 +17,12 @@ import ( // RedundancySingleKeySteppedIn exposes the redundancySingleKeySteppedIn constant const RedundancySingleKeySteppedIn = redundancySingleKeySteppedIn +// LeaderSingleKeyStartMsg - +const LeaderSingleKeyStartMsg = singleKeyStartMsg + +// LeaderMultiKeyStartMsg - +const LeaderMultiKeyStartMsg = multiKeyStartMsg + type RoundConsensus struct { *roundConsensus } @@ -142,17 +152,17 @@ func (wrk *Worker) NilReceivedMessages() { } // ReceivedMessagesCalls - -func (wrk *Worker) ReceivedMessagesCalls() map[consensus.MessageType]func(context.Context, *consensus.Message) bool { +func (wrk *Worker) ReceivedMessagesCalls() map[consensus.MessageType][]func(context.Context, *consensus.Message) bool { wrk.mutReceivedMessagesCalls.RLock() defer wrk.mutReceivedMessagesCalls.RUnlock() return wrk.receivedMessagesCalls } -// SetReceivedMessagesCalls - -func (wrk *Worker) SetReceivedMessagesCalls(messageType consensus.MessageType, f func(context.Context, *consensus.Message) bool) { +// AppendReceivedMessagesCalls - +func (wrk *Worker) AppendReceivedMessagesCalls(messageType consensus.MessageType, f func(context.Context, *consensus.Message) bool) { wrk.mutReceivedMessagesCalls.Lock() - wrk.receivedMessagesCalls[messageType] = f + wrk.receivedMessagesCalls[messageType] = append(wrk.receivedMessagesCalls[messageType], f) wrk.mutReceivedMessagesCalls.Unlock() } @@ -161,6 +171,26 @@ func (wrk *Worker) ExecuteMessageChannel() chan *consensus.Message { return wrk.executeMessageChannel } +// ConvertHeaderToConsensusMessage - +func (wrk *Worker) ConvertHeaderToConsensusMessage(header data.HeaderHandler) (*consensus.Message, error) { + return wrk.convertHeaderToConsensusMessage(header) +} + +// Hasher - +func (wrk *Worker) Hasher() data.Hasher { + return wrk.hasher +} + +// SetEnableEpochsHandler +func (wrk *Worker) SetEnableEpochsHandler(enableEpochsHandler common.EnableEpochsHandler) { + wrk.enableEpochsHandler = enableEpochsHandler +} + +// AddFutureHeaderToProcessIfNeeded - +func (wrk *Worker) AddFutureHeaderToProcessIfNeeded(header data.HeaderHandler) { + wrk.addFutureHeaderToProcessIfNeeded(header) +} + // ConsensusStateChangedChannel - func (wrk *Worker) ConsensusStateChangedChannel() chan bool { return wrk.consensusStateChangedChannel @@ -265,3 +295,61 @@ func (cmv *consensusMessageValidator) GetNumOfMessageTypeForPublicKey(pk []byte, func (cmv *consensusMessageValidator) ResetConsensusMessages() { cmv.resetConsensusMessages() } + +// SetStatus - +func (sp *scheduledProcessorWrapper) SetStatus(status processingStatus) { + 
sp.setStatus(status) +} + +// GetStatus - +func (sp *scheduledProcessorWrapper) GetStatus() processingStatus { + return sp.getStatus() +} + +// SetStartTime - +func (sp *scheduledProcessorWrapper) SetStartTime(t time.Time) { + sp.startTime = t +} + +// GetStartTime - +func (sp *scheduledProcessorWrapper) GetStartTime() time.Time { + return sp.startTime +} + +// GetRoundTimeHandler - +func (sp *scheduledProcessorWrapper) GetRoundTimeHandler() process.RoundTimeDurationHandler { + return sp.roundTimeDurationHandler +} + +// ProcessingNotStarted - +var ProcessingNotStarted = processingNotStarted + +// ProcessingError - +var ProcessingError = processingError + +// InProgress - +var InProgress = inProgress + +// ProcessingOK - +var ProcessingOK = processingOK + +// Stopped - +var Stopped = stopped + +// ProcessingNotStartedString - +var ProcessingNotStartedString = processingNotStartedString + +// ProcessingErrorString - +var ProcessingErrorString = processingErrorString + +// InProgressString - +var InProgressString = inProgressString + +// ProcessingOKString - +var ProcessingOKString = processingOKString + +// StoppedString - +var StoppedString = stoppedString + +// UnexpectedString - +var UnexpectedString = unexpectedString diff --git a/consensus/spos/interface.go b/consensus/spos/interface.go index 0ca771d30e5..1decc6aabc3 100644 --- a/consensus/spos/interface.go +++ b/consensus/spos/interface.go @@ -9,6 +9,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/outport" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/epochStart" @@ -21,51 +23,32 @@ import ( // ConsensusCoreHandler encapsulates all needed data for the Consensus type ConsensusCoreHandler interface { - // Blockchain gets the ChainHandler stored in the ConsensusCore Blockchain() data.ChainHandler - // BlockProcessor gets the BlockProcessor stored in the ConsensusCore BlockProcessor() process.BlockProcessor - // BootStrapper gets the Bootstrapper stored in the ConsensusCore BootStrapper() process.Bootstrapper - // BroadcastMessenger gets the BroadcastMessenger stored in ConsensusCore BroadcastMessenger() consensus.BroadcastMessenger - // Chronology gets the ChronologyHandler stored in the ConsensusCore Chronology() consensus.ChronologyHandler - // GetAntiFloodHandler returns the antiflood handler which will be used in subrounds GetAntiFloodHandler() consensus.P2PAntifloodHandler - // Hasher gets the Hasher stored in the ConsensusCore Hasher() hashing.Hasher - // Marshalizer gets the Marshalizer stored in the ConsensusCore Marshalizer() marshal.Marshalizer - // MultiSignerContainer gets the MultiSigner container from the ConsensusCore MultiSignerContainer() cryptoCommon.MultiSignerContainer - // RoundHandler gets the RoundHandler stored in the ConsensusCore RoundHandler() consensus.RoundHandler - // ShardCoordinator gets the ShardCoordinator stored in the ConsensusCore ShardCoordinator() sharding.Coordinator - // SyncTimer gets the SyncTimer stored in the ConsensusCore SyncTimer() ntp.SyncTimer - // NodesCoordinator gets the NodesCoordinator stored in the ConsensusCore NodesCoordinator() nodesCoordinator.NodesCoordinator - // EpochStartRegistrationHandler gets the 
RegistrationHandler stored in the ConsensusCore EpochStartRegistrationHandler() epochStart.RegistrationHandler - // PeerHonestyHandler returns the peer honesty handler which will be used in subrounds PeerHonestyHandler() consensus.PeerHonestyHandler - // HeaderSigVerifier returns the sig verifier handler which will be used in subrounds HeaderSigVerifier() consensus.HeaderSigVerifier - // FallbackHeaderValidator returns the fallback header validator handler which will be used in subrounds FallbackHeaderValidator() consensus.FallbackHeaderValidator - // NodeRedundancyHandler returns the node redundancy handler which will be used in subrounds NodeRedundancyHandler() consensus.NodeRedundancyHandler - // ScheduledProcessor returns the scheduled txs processor ScheduledProcessor() consensus.ScheduledProcessor - // MessageSigningHandler returns the p2p signing handler MessageSigningHandler() consensus.P2PSigningHandler - // PeerBlacklistHandler return the peer blacklist handler PeerBlacklistHandler() consensus.PeerBlacklistHandler - // SigningHandler returns the signing handler component SigningHandler() consensus.SigningHandler - // IsInterfaceNil returns true if there is no value under the interface + EnableEpochsHandler() common.EnableEpochsHandler + EquivalentProofsPool() consensus.EquivalentProofsPool + EpochNotifier() process.EpochNotifier + InvalidSignersCache() InvalidSignersCache IsInterfaceNil() bool } @@ -104,17 +87,12 @@ type ConsensusService interface { GetMaxMessagesInARoundPerPeer() uint32 // GetMaxNumOfMessageTypeAccepted returns the maximum number of accepted consensus message types per round, per public key GetMaxNumOfMessageTypeAccepted(msgType consensus.MessageType) uint32 + // GetMessageTypeBlockHeader returns the message type for the block header + GetMessageTypeBlockHeader() consensus.MessageType // IsInterfaceNil returns true if there is no value under the interface IsInterfaceNil() bool } -// SubroundsFactory encapsulates the methods specifically for a subrounds factory type (bls, bn) -// for different consensus types -type SubroundsFactory interface { - GenerateSubrounds() error - IsInterfaceNil() bool -} - // WorkerHandler represents the interface for the SposWorker type WorkerHandler interface { Close() error @@ -123,10 +101,14 @@ type WorkerHandler interface { AddReceivedMessageCall(messageType consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) // AddReceivedHeaderHandler adds a new handler function for a received header AddReceivedHeaderHandler(handler func(data.HeaderHandler)) + // RemoveAllReceivedHeaderHandlers removes all the functions handlers + RemoveAllReceivedHeaderHandlers() + // AddReceivedProofHandler adds a new handler function for a received proof + AddReceivedProofHandler(handler func(consensus.ProofHandler)) // RemoveAllReceivedMessagesCalls removes all the functions handlers RemoveAllReceivedMessagesCalls() // ProcessReceivedMessage method redirects the received message to the channel which should handle it - ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error + ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) // Extend does an extension for the subround with subroundId Extend(subroundId int) // GetConsensusStateChangedChannel gets the channel for the consensusStateChanged @@ -137,8 +119,12 @@ type WorkerHandler interface { DisplayStatistics() // ReceivedHeader method is a 
wired method through which worker will receive headers from network ReceivedHeader(headerHandler data.HeaderHandler, headerHash []byte) - // ResetConsensusMessages resets at the start of each round all the previous consensus messages received + // ResetConsensusMessages resets at the start of each round all the previous consensus messages received and equivalent messages, keeping the provided proofs ResetConsensusMessages() + // ResetConsensusRoundState resets the consensus round state when transitioning to a different consensus version + ResetConsensusRoundState() + // ResetInvalidSignersCache resets the invalid signers cache + ResetInvalidSignersCache() // IsInterfaceNil returns true if there is no value under the interface IsInterfaceNil() bool } @@ -154,6 +140,8 @@ type HeaderSigVerifier interface { VerifyRandSeed(header data.HeaderHandler) error VerifyLeaderSignature(header data.HeaderHandler) error VerifySignature(header data.HeaderHandler) error + VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error + VerifyHeaderProof(headerProof data.HeaderProofHandler) error IsInterfaceNil() bool } @@ -177,3 +165,108 @@ type SentSignaturesTracker interface { SignatureSent(pkBytes []byte) IsInterfaceNil() bool } + +// ConsensusStateHandler encapsulates all needed data for the Consensus +type ConsensusStateHandler interface { + ResetConsensusState() + ResetConsensusRoundState() + AddReceivedHeader(headerHandler data.HeaderHandler) + GetReceivedHeaders() []data.HeaderHandler + AddMessageWithSignature(key string, message p2p.MessageP2P) + GetMessageWithSignature(key string) (p2p.MessageP2P, bool) + IsNodeLeaderInCurrentRound(node string) bool + GetLeader() (string, error) + GetNextConsensusGroup( + randomSource []byte, + round uint64, + shardId uint32, + nodesCoordinator nodesCoordinator.NodesCoordinator, + epoch uint32, + ) (string, []string, error) + IsConsensusDataSet() bool + IsConsensusDataEqual(data []byte) bool + IsJobDone(node string, currentSubroundId int) bool + IsSubroundFinished(subroundID int) bool + IsNodeSelf(node string) bool + IsBlockBodyAlreadyReceived() bool + IsHeaderAlreadyReceived() bool + CanDoSubroundJob(currentSubroundId int) bool + CanProcessReceivedMessage(cnsDta *consensus.Message, currentRoundIndex int64, currentSubroundId int) bool + GenerateBitmap(subroundId int) []byte + ProcessingBlock() bool + SetProcessingBlock(processingBlock bool) + GetData() []byte + SetData(data []byte) + IsMultiKeyLeaderInCurrentRound() bool + IsLeaderJobDone(currentSubroundId int) bool + IsMultiKeyJobDone(currentSubroundId int) bool + IsSelfJobDone(currentSubroundID int) bool + GetMultikeyRedundancyStepInReason() string + ResetRoundsWithoutReceivedMessages(pkBytes []byte, pid core.PeerID) + GetRoundCanceled() bool + SetRoundCanceled(state bool) + GetRoundIndex() int64 + SetRoundIndex(roundIndex int64) + GetRoundTimeStamp() time.Time + SetRoundTimeStamp(roundTimeStamp time.Time) + GetExtendedCalled() bool + GetBody() data.BodyHandler + SetBody(body data.BodyHandler) + GetHeader() data.HeaderHandler + SetHeader(header data.HeaderHandler) + GetWaitingAllSignaturesTimeOut() bool + SetWaitingAllSignaturesTimeOut(bool) + RoundConsensusHandler + RoundStatusHandler + RoundThresholdHandler + IsInterfaceNil() bool +} + +// RoundConsensusHandler encapsulates the methods needed for a consensus round +type RoundConsensusHandler interface { + ConsensusGroupIndex(pubKey string) (int, error) + SelfConsensusGroupIndex() (int, error) + 
SetEligibleList(eligibleList map[string]struct{}) + ConsensusGroup() []string + SetConsensusGroup(consensusGroup []string) + SetLeader(leader string) + ConsensusGroupSize() int + SetConsensusGroupSize(consensusGroupSize int) + SelfPubKey() string + SetSelfPubKey(selfPubKey string) + JobDone(key string, subroundId int) (bool, error) + SetJobDone(key string, subroundId int, value bool) error + SelfJobDone(subroundId int) (bool, error) + IsNodeInConsensusGroup(node string) bool + IsNodeInEligibleList(node string) bool + ComputeSize(subroundId int) int + ResetRoundState() + IsMultiKeyInConsensusGroup() bool + IsKeyManagedBySelf(pkBytes []byte) bool + IncrementRoundsWithoutReceivedMessages(pkBytes []byte) + GetKeysHandler() consensus.KeysHandler + Leader() string +} + +// RoundStatusHandler encapsulates the methods needed for the status of a subround +type RoundStatusHandler interface { + Status(subroundId int) SubroundStatus + SetStatus(subroundId int, subroundStatus SubroundStatus) + ResetRoundStatus() +} + +// RoundThresholdHandler encapsulates the methods needed for the round consensus threshold +type RoundThresholdHandler interface { + Threshold(subroundId int) int + SetThreshold(subroundId int, threshold int) + FallbackThreshold(subroundId int) int + SetFallbackThreshold(subroundId int, threshold int) +} + +// InvalidSignersCache encapsulates the methods needed for a invalid signers cache +type InvalidSignersCache interface { + AddInvalidSigners(headerHash []byte, invalidSigners []byte, invalidPublicKeys []string) + CheckKnownInvalidSigners(headerHash []byte, invalidSigners []byte) bool + Reset() + IsInterfaceNil() bool +} diff --git a/consensus/spos/invalidSignersCache.go b/consensus/spos/invalidSignersCache.go new file mode 100644 index 00000000000..ca26867f49d --- /dev/null +++ b/consensus/spos/invalidSignersCache.go @@ -0,0 +1,130 @@ +package spos + +import ( + "sync" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/p2p" +) + +// ArgInvalidSignersCache defines the DTO used to create a new instance of invalidSignersCache +type ArgInvalidSignersCache struct { + Hasher hashing.Hasher + SigningHandler p2p.P2PSigningHandler + Marshaller marshal.Marshalizer +} + +type invalidSignersCache struct { + sync.RWMutex + invalidSignersHashesMap map[string]struct{} + invalidSignersForHeaderMap map[string]map[string]struct{} + hasher hashing.Hasher + signingHandler p2p.P2PSigningHandler + marshaller marshal.Marshalizer +} + +// NewInvalidSignersCache returns a new instance of invalidSignersCache +func NewInvalidSignersCache(args ArgInvalidSignersCache) (*invalidSignersCache, error) { + err := checkArgs(args) + if err != nil { + return nil, err + } + + return &invalidSignersCache{ + invalidSignersHashesMap: make(map[string]struct{}), + invalidSignersForHeaderMap: make(map[string]map[string]struct{}), + hasher: args.Hasher, + signingHandler: args.SigningHandler, + marshaller: args.Marshaller, + }, nil +} + +func checkArgs(args ArgInvalidSignersCache) error { + if check.IfNil(args.Hasher) { + return ErrNilHasher + } + if check.IfNil(args.SigningHandler) { + return ErrNilSigningHandler + } + if check.IfNil(args.Marshaller) { + return ErrNilMarshalizer + } + + return nil +} + +// AddInvalidSigners adds the provided hash into the 
internal map if it does not exist +func (cache *invalidSignersCache) AddInvalidSigners(headerHash []byte, invalidSigners []byte, invalidPublicKeys []string) { + if len(invalidPublicKeys) == 0 || len(invalidSigners) == 0 { + return + } + + cache.Lock() + defer cache.Unlock() + + invalidSignersHash := cache.hasher.Compute(string(invalidSigners)) + cache.invalidSignersHashesMap[string(invalidSignersHash)] = struct{}{} + + _, ok := cache.invalidSignersForHeaderMap[string(headerHash)] + if !ok { + cache.invalidSignersForHeaderMap[string(headerHash)] = make(map[string]struct{}) + } + + for _, pk := range invalidPublicKeys { + cache.invalidSignersForHeaderMap[string(headerHash)][pk] = struct{}{} + } +} + +// CheckKnownInvalidSigners checks whether all the provided invalid signers are known for the header hash +func (cache *invalidSignersCache) CheckKnownInvalidSigners(headerHash []byte, serializedInvalidSigners []byte) bool { + cache.RLock() + defer cache.RUnlock() + + invalidSignersHash := cache.hasher.Compute(string(serializedInvalidSigners)) + _, hasSameInvalidSigners := cache.invalidSignersHashesMap[string(invalidSignersHash)] + if hasSameInvalidSigners { + return true + } + + _, isHeaderKnown := cache.invalidSignersForHeaderMap[string(headerHash)] + if !isHeaderKnown { + return false + } + + invalidSignersP2PMessages, err := cache.signingHandler.Deserialize(serializedInvalidSigners) + if err != nil { + return false + } + + for _, msg := range invalidSignersP2PMessages { + cnsMsg := &consensus.Message{} + err = cache.marshaller.Unmarshal(cnsMsg, msg.Data()) + if err != nil { + return false + } + + _, isKnownInvalidSigner := cache.invalidSignersForHeaderMap[string(headerHash)][string(cnsMsg.PubKey)] + if !isKnownInvalidSigner { + return false + } + } + + return true +} + +// Reset clears the internal maps +func (cache *invalidSignersCache) Reset() { + cache.Lock() + defer cache.Unlock() + + cache.invalidSignersHashesMap = make(map[string]struct{}) + cache.invalidSignersForHeaderMap = make(map[string]map[string]struct{}) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (cache *invalidSignersCache) IsInterfaceNil() bool { + return cache == nil +} diff --git a/consensus/spos/invalidSignersCache_test.go b/consensus/spos/invalidSignersCache_test.go new file mode 100644 index 00000000000..354a5b9306c --- /dev/null +++ b/consensus/spos/invalidSignersCache_test.go @@ -0,0 +1,201 @@ +package spos + +import ( + "crypto/rand" + "fmt" + "sync" + "testing" + "time" + + pubsub "github.com/libp2p/go-libp2p-pubsub" + pb "github.com/libp2p/go-libp2p-pubsub/pb" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/multiversx/mx-chain-communication-go/p2p/data" + "github.com/multiversx/mx-chain-communication-go/p2p/libp2p" + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/consensus" + consensusMock "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/stretchr/testify/require" +) + +func createMockArgs() ArgInvalidSignersCache { + return ArgInvalidSignersCache{ + Hasher: &testscommon.HasherStub{}, + SigningHandler: 
&consensusMock.MessageSigningHandlerStub{}, + Marshaller: &marshallerMock.MarshalizerStub{}, + } +} + +func TestInvalidSignersCache_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var cache *invalidSignersCache + require.True(t, cache.IsInterfaceNil()) + + cache, _ = NewInvalidSignersCache(createMockArgs()) + require.False(t, cache.IsInterfaceNil()) +} + +func TestNewInvalidSignersCache(t *testing.T) { + t.Parallel() + + t.Run("nil Hasher should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgs() + args.Hasher = nil + + cache, err := NewInvalidSignersCache(args) + require.Equal(t, ErrNilHasher, err) + require.Nil(t, cache) + }) + t.Run("nil SigningHandler should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgs() + args.SigningHandler = nil + + cache, err := NewInvalidSignersCache(args) + require.Equal(t, ErrNilSigningHandler, err) + require.Nil(t, cache) + }) + t.Run("nil Marshaller should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgs() + args.Marshaller = nil + + cache, err := NewInvalidSignersCache(args) + require.Equal(t, ErrNilMarshalizer, err) + require.Nil(t, cache) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + cache, err := NewInvalidSignersCache(createMockArgs()) + require.NoError(t, err) + require.NotNil(t, cache) + }) +} + +func TestInvalidSignersCache(t *testing.T) { + t.Parallel() + + t.Run("all ops should work", func(t *testing.T) { + t.Parallel() + + headerHash1 := []byte("headerHash11") + invalidSigners1 := []byte("invalidSigners1") + pubKeys1 := []string{"pk0", "pk1"} + invalidSigners2 := []byte("invalidSigners2") + + args := createMockArgs() + args.Hasher = &testscommon.HasherStub{ + ComputeCalled: func(s string) []byte { + return []byte(s) + }, + } + args.SigningHandler = &consensusMock.MessageSigningHandlerStub{ + DeserializeCalled: func(messagesBytes []byte) ([]p2p.MessageP2P, error) { + if string(messagesBytes) == string(invalidSigners1) { + m1, _ := libp2p.NewMessage(createDummyP2PMessage(), &testscommon.ProtoMarshalizerMock{}, "") + m2, _ := libp2p.NewMessage(createDummyP2PMessage(), &testscommon.ProtoMarshalizerMock{}, "") + return []p2p.MessageP2P{m1, m2}, nil + } + + m1, _ := libp2p.NewMessage(createDummyP2PMessage(), &testscommon.ProtoMarshalizerMock{}, "") + return []p2p.MessageP2P{m1}, nil + }, + } + cnt := 0 + args.Marshaller = &marshallerMock.MarshalizerStub{ + UnmarshalCalled: func(obj interface{}, buff []byte) error { + message := obj.(*consensus.Message) + message.PubKey = []byte(fmt.Sprintf("pk%d", cnt)) + cnt++ + + return nil + }, + } + cache, _ := NewInvalidSignersCache(args) + require.NotNil(t, cache) + + cache.AddInvalidSigners(nil, nil, nil) // early return, for coverage only + + require.False(t, cache.CheckKnownInvalidSigners(headerHash1, invalidSigners1)) + + cache.AddInvalidSigners(headerHash1, invalidSigners1, pubKeys1) + require.True(t, cache.CheckKnownInvalidSigners(headerHash1, invalidSigners1)) // should find in signers by hashes map + + require.True(t, cache.CheckKnownInvalidSigners(headerHash1, invalidSigners2)) // should have different hash but the known signers + + cache.Reset() + require.False(t, cache.CheckKnownInvalidSigners(headerHash1, invalidSigners1)) + }) + t.Run("concurrent ops should work", func(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r != nil { + require.Fail(t, "should have not panicked") + } + }() + + args := createMockArgs() + cache, _ := NewInvalidSignersCache(args) + require.NotNil(t, cache) + + 
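A conceptual sketch of the two-level lookup used by invalidSignersCache above: a fast path keyed on the hash of the serialized invalid-signers payload, and a fallback that checks whether every reported public key is already known for the same header. This is not the production code: sha256 stands in for the injected hasher, the caller passes the public keys directly instead of deserializing p2p messages, and the mutex is omitted.

package main

import (
	"crypto/sha256"
	"fmt"
)

// toyCache mirrors the two maps kept by invalidSignersCache: one keyed on the
// hash of the serialized invalid-signers payload, one keyed on header hash and
// holding the public keys already reported as invalid for that header.
type toyCache struct {
	byPayloadHash map[[32]byte]struct{}
	byHeader      map[string]map[string]struct{}
}

func newToyCache() *toyCache {
	return &toyCache{
		byPayloadHash: make(map[[32]byte]struct{}),
		byHeader:      make(map[string]map[string]struct{}),
	}
}

func (c *toyCache) add(headerHash string, payload []byte, pubKeys []string) {
	if len(payload) == 0 || len(pubKeys) == 0 {
		return
	}
	c.byPayloadHash[sha256.Sum256(payload)] = struct{}{}
	if _, ok := c.byHeader[headerHash]; !ok {
		c.byHeader[headerHash] = make(map[string]struct{})
	}
	for _, pk := range pubKeys {
		c.byHeader[headerHash][pk] = struct{}{}
	}
}

// known reports whether the exact payload was seen before, or whether every
// public key it carries was already reported for the same header.
func (c *toyCache) known(headerHash string, payload []byte, pubKeys []string) bool {
	if _, ok := c.byPayloadHash[sha256.Sum256(payload)]; ok {
		return true
	}
	reported, ok := c.byHeader[headerHash]
	if !ok {
		return false
	}
	for _, pk := range pubKeys {
		if _, found := reported[pk]; !found {
			return false
		}
	}
	return true
}

func main() {
	cache := newToyCache()
	cache.add("header1", []byte("signers-v1"), []string{"pk0", "pk1"})

	fmt.Println(cache.known("header1", []byte("signers-v1"), nil))             // true: same payload seen before
	fmt.Println(cache.known("header1", []byte("signers-v2"), []string{"pk0"})) // true: pk0 already reported for header1
	fmt.Println(cache.known("header1", []byte("signers-v2"), []string{"pk7"})) // false: pk7 is a new invalid signer
}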
numCalls := 1000 + wg := sync.WaitGroup{} + wg.Add(numCalls) + + for i := 0; i < numCalls; i++ { + go func(idx int) { + switch idx % 3 { + case 0: + cache.AddInvalidSigners([]byte("hash"), []byte("invalidSigners"), []string{"pk0", "pk1"}) + case 1: + cache.CheckKnownInvalidSigners([]byte("hash"), []byte("invalidSigners")) + case 2: + cache.Reset() + default: + require.Fail(t, "should not happen") + } + + wg.Done() + }(i) + } + + wg.Wait() + }) +} + +func createDummyP2PMessage() *pubsub.Message { + marshaller := &testscommon.ProtoMarshalizerMock{} + topicMessage := &data.TopicMessage{ + Timestamp: time.Now().Unix(), + Payload: []byte("data"), + Version: 1, + } + buff, _ := marshaller.Marshal(topicMessage) + topic := "topic" + mes := &pb.Message{ + From: getRandomID().Bytes(), + Data: buff, + Topic: &topic, + } + + return &pubsub.Message{Message: mes} +} + +func getRandomID() core.PeerID { + prvKey, _, _ := crypto.GenerateSecp256k1Key(rand.Reader) + id, _ := peer.IDFromPublicKey(prvKey.GetPublic()) + + return core.PeerID(id) +} diff --git a/consensus/spos/roundConsensus.go b/consensus/spos/roundConsensus.go index b230e124a15..434adc6258d 100644 --- a/consensus/spos/roundConsensus.go +++ b/consensus/spos/roundConsensus.go @@ -4,6 +4,7 @@ import ( "sync" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/consensus" ) @@ -12,6 +13,7 @@ type roundConsensus struct { eligibleNodes map[string]struct{} mutEligible sync.RWMutex consensusGroup []string + leader string consensusGroupSize int selfPubKey string validatorRoundStates map[string]*roundState @@ -64,15 +66,18 @@ func (rcns *roundConsensus) SetEligibleList(eligibleList map[string]struct{}) { // ConsensusGroup returns the consensus group ID's func (rcns *roundConsensus) ConsensusGroup() []string { + rcns.mut.RLock() + defer rcns.mut.RUnlock() + return rcns.consensusGroup } // SetConsensusGroup sets the consensus group ID's func (rcns *roundConsensus) SetConsensusGroup(consensusGroup []string) { - rcns.consensusGroup = consensusGroup - rcns.mut.Lock() + rcns.consensusGroup = consensusGroup + rcns.validatorRoundStates = make(map[string]*roundState) for i := 0; i < len(consensusGroup); i++ { @@ -82,14 +87,30 @@ func (rcns *roundConsensus) SetConsensusGroup(consensusGroup []string) { rcns.mut.Unlock() } +// Leader returns the leader for the current consensus +func (rcns *roundConsensus) Leader() string { + rcns.mut.RLock() + defer rcns.mut.RUnlock() + + return rcns.leader +} + +// SetLeader sets the leader for the current consensus +func (rcns *roundConsensus) SetLeader(leader string) { + rcns.mut.Lock() + defer rcns.mut.Unlock() + + rcns.leader = leader +} + // ConsensusGroupSize returns the consensus group size func (rcns *roundConsensus) ConsensusGroupSize() int { return rcns.consensusGroupSize } // SetConsensusGroupSize sets the consensus group size -func (rcns *roundConsensus) SetConsensusGroupSize(consensusGroudpSize int) { - rcns.consensusGroupSize = consensusGroudpSize +func (rcns *roundConsensus) SetConsensusGroupSize(consensusGroupSize int) { + rcns.consensusGroupSize = consensusGroupSize } // SelfPubKey returns selfPubKey ID @@ -144,6 +165,9 @@ func (rcns *roundConsensus) SelfJobDone(subroundId int) (bool, error) { // IsNodeInConsensusGroup method checks if the node is part of consensus group of the current round func (rcns *roundConsensus) IsNodeInConsensusGroup(node string) bool { + rcns.mut.RLock() + defer rcns.mut.RUnlock() + for i := 0; i < 
len(rcns.consensusGroup); i++ { if rcns.consensusGroup[i] == node { return true @@ -205,7 +229,7 @@ func (rcns *roundConsensus) ResetRoundState() { // is in consensus group in the current round func (rcns *roundConsensus) IsMultiKeyInConsensusGroup() bool { for i := 0; i < len(rcns.consensusGroup); i++ { - if rcns.IsKeyManagedByCurrentNode([]byte(rcns.consensusGroup[i])) { + if rcns.IsKeyManagedBySelf([]byte(rcns.consensusGroup[i])) { return true } } @@ -213,8 +237,8 @@ func (rcns *roundConsensus) IsMultiKeyInConsensusGroup() bool { return false } -// IsKeyManagedByCurrentNode returns true if the key is managed by the current node -func (rcns *roundConsensus) IsKeyManagedByCurrentNode(pkBytes []byte) bool { +// IsKeyManagedBySelf returns true if the key is managed by the current node +func (rcns *roundConsensus) IsKeyManagedBySelf(pkBytes []byte) bool { return rcns.keysHandler.IsKeyManagedByCurrentNode(pkBytes) } @@ -222,3 +246,8 @@ func (rcns *roundConsensus) IsKeyManagedByCurrentNode(pkBytes []byte) bool { func (rcns *roundConsensus) IncrementRoundsWithoutReceivedMessages(pkBytes []byte) { rcns.keysHandler.IncrementRoundsWithoutReceivedMessages(pkBytes) } + +// GetKeysHandler returns the keysHandler instance +func (rcns *roundConsensus) GetKeysHandler() consensus.KeysHandler { + return rcns.keysHandler +} diff --git a/consensus/spos/roundConsensus_test.go b/consensus/spos/roundConsensus_test.go index 4ba8f7e47fe..36c8e5ad8ab 100644 --- a/consensus/spos/roundConsensus_test.go +++ b/consensus/spos/roundConsensus_test.go @@ -296,23 +296,6 @@ func TestRoundConsensus_IsMultiKeyInConsensusGroup(t *testing.T) { }) } -func TestRoundConsensus_IsKeyManagedByCurrentNode(t *testing.T) { - t.Parallel() - - managedPkBytes := []byte("managed pk bytes") - wasCalled := false - keysHandler := &testscommon.KeysHandlerStub{ - IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { - assert.Equal(t, managedPkBytes, pkBytes) - wasCalled = true - return true - }, - } - roundConsensus := initRoundConsensusWithKeysHandler(keysHandler) - assert.True(t, roundConsensus.IsKeyManagedByCurrentNode(managedPkBytes)) - assert.True(t, wasCalled) -} - func TestRoundConsensus_IncrementRoundsWithoutReceivedMessages(t *testing.T) { t.Parallel() diff --git a/consensus/spos/roundStatus.go b/consensus/spos/roundStatus.go index 8517396904a..7d3b67fdc15 100644 --- a/consensus/spos/roundStatus.go +++ b/consensus/spos/roundStatus.go @@ -5,7 +5,7 @@ import ( ) // SubroundStatus defines the type used to refer the state of the current subround -type SubroundStatus int +type SubroundStatus = int const ( // SsNotFinished defines the un-finished state of the subround diff --git a/consensus/spos/scheduledProcessor_test.go b/consensus/spos/scheduledProcessor_test.go index 7316209921b..0ec6ca3d5b2 100644 --- a/consensus/spos/scheduledProcessor_test.go +++ b/consensus/spos/scheduledProcessor_test.go @@ -1,4 +1,4 @@ -package spos +package spos_test import ( "errors" @@ -8,46 +8,49 @@ import ( "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" - "github.com/multiversx/mx-chain-go/consensus/mock" + + "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + 
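The roundStatus.go hunk above turns SubroundStatus into a type alias of int rather than a defined type. A short illustration of the difference in general Go semantics (the names below are hypothetical, not repository code): an alias is identical to int and needs no conversion, while a defined type does.

package main

import "fmt"

// SubroundStatusAlias mirrors the new form: a type alias is identical to int,
// so plain int values flow into it without any conversion.
type SubroundStatusAlias = int

// SubroundStatusDefined mirrors the previous form: a defined type that
// requires an explicit conversion from int and could carry its own methods.
type SubroundStatusDefined int

func main() {
	n := 2

	var viaAlias SubroundStatusAlias = n // compiles: the alias IS int
	// var viaDefined SubroundStatusDefined = n // would not compile without a conversion
	viaDefined := SubroundStatusDefined(n)

	fmt.Println(viaAlias, int(viaDefined))
}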
"github.com/stretchr/testify/require" ) func TestProcessingStatus_String(t *testing.T) { t.Parallel() - require.Equal(t, processingNotStartedString, processingNotStarted.String()) - require.Equal(t, processingErrorString, processingError.String()) - require.Equal(t, inProgressString, inProgress.String()) - require.Equal(t, processingOKString, processingOK.String()) - require.Equal(t, stoppedString, stopped.String()) + require.Equal(t, spos.ProcessingNotStartedString, spos.ProcessingNotStarted.String()) + require.Equal(t, spos.ProcessingErrorString, spos.ProcessingError.String()) + require.Equal(t, spos.InProgressString, spos.InProgress.String()) + require.Equal(t, spos.ProcessingOKString, spos.ProcessingOK.String()) + require.Equal(t, spos.StoppedString, spos.Stopped.String()) } func TestNewScheduledProcessorWrapper_NilSyncTimerShouldErr(t *testing.T) { t.Parallel() - args := ScheduledProcessorWrapperArgs{ + args := spos.ScheduledProcessorWrapperArgs{ SyncTimer: nil, Processor: &testscommon.BlockProcessorStub{}, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, sp) - require.Equal(t, ErrNilSyncTimer, err) + require.Equal(t, spos.ErrNilSyncTimer, err) } func TestNewScheduledProcessorWrapper_NilBlockProcessorShouldErr(t *testing.T) { t.Parallel() - args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{}, Processor: nil, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, sp) require.Equal(t, process.ErrNilBlockProcessor, err) } @@ -55,13 +58,13 @@ func TestNewScheduledProcessorWrapper_NilBlockProcessorShouldErr(t *testing.T) { func TestNewScheduledProcessorWrapper_NilRoundTimeDurationHandlerShouldErr(t *testing.T) { t.Parallel() - args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{}, RoundTimeDurationHandler: nil, } - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, sp) require.Equal(t, process.ErrNilRoundTimeDurationHandler, err) } @@ -69,13 +72,13 @@ func TestNewScheduledProcessorWrapper_NilRoundTimeDurationHandlerShouldErr(t *te func TestNewScheduledProcessorWrapper_NilBlockProcessorOK(t *testing.T) { t.Parallel() - args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{}, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, err) require.NotNil(t, sp) } @@ -84,41 +87,41 @@ func TestScheduledProcessorWrapper_IsProcessedOKEarlyExit(t *testing.T) { t.Parallel() called := atomic.Flag{} - args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{ + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{ CurrentTimeCalled: func() time.Time { 
called.SetValue(true) return time.Now() }, }, Processor: &testscommon.BlockProcessorStub{}, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, err) require.False(t, sp.IsProcessedOKWithTimeout()) require.False(t, called.IsSet()) - sp.setStatus(processingOK) + sp.SetStatus(spos.ProcessingOK) require.True(t, sp.IsProcessedOKWithTimeout()) require.False(t, called.IsSet()) - sp.setStatus(processingError) + sp.SetStatus(spos.ProcessingError) require.False(t, sp.IsProcessedOKWithTimeout()) require.False(t, called.IsSet()) } -func defaultScheduledProcessorWrapperArgs() ScheduledProcessorWrapperArgs { - return ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{ +func defaultScheduledProcessorWrapperArgs() spos.ScheduledProcessorWrapperArgs { + return spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{ CurrentTimeCalled: func() time.Time { return time.Now() }, }, Processor: &testscommon.BlockProcessorStub{}, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } } @@ -126,30 +129,30 @@ func TestScheduledProcessorWrapper_IsProcessedInProgressNegativeRemainingTime(t t.Parallel() args := defaultScheduledProcessorWrapperArgs() - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, err) - sp.setStatus(inProgress) + sp.SetStatus(spos.InProgress) require.False(t, sp.IsProcessedOKWithTimeout()) startTime := time.Now() - sp.startTime = startTime.Add(-200 * time.Millisecond) + sp.SetStartTime(startTime.Add(-200 * time.Millisecond)) require.False(t, sp.IsProcessedOKWithTimeout()) endTime := time.Now() timeSpent := endTime.Sub(startTime) - require.Less(t, timeSpent, sp.roundTimeDurationHandler.TimeDuration()) + require.Less(t, timeSpent, sp.GetRoundTimeHandler().TimeDuration()) } func TestScheduledProcessorWrapper_IsProcessedInProgressStartingInFuture(t *testing.T) { t.Parallel() args := defaultScheduledProcessorWrapperArgs() - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, err) - sp.setStatus(inProgress) + sp.SetStatus(spos.InProgress) startTime := time.Now() - sp.startTime = startTime.Add(500 * time.Millisecond) + sp.SetStartTime(startTime.Add(500 * time.Millisecond)) require.False(t, sp.IsProcessedOKWithTimeout()) endTime := time.Now() require.Less(t, endTime.Sub(startTime), time.Millisecond*100) @@ -159,172 +162,172 @@ func TestScheduledProcessorWrapper_IsProcessedInProgressEarlyCompletion(t *testi t.Parallel() args := defaultScheduledProcessorWrapperArgs() - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, err) - sp.setStatus(inProgress) - sp.startTime = time.Now() + sp.SetStatus(spos.InProgress) + sp.SetStartTime(time.Now()) go func() { time.Sleep(10 * time.Millisecond) - sp.setStatus(processingOK) + sp.SetStatus(spos.ProcessingOK) }() require.True(t, sp.IsProcessedOKWithTimeout()) endTime := time.Now() - timeSpent := endTime.Sub(sp.startTime) - require.Less(t, timeSpent, sp.roundTimeDurationHandler.TimeDuration()) + timeSpent := endTime.Sub(sp.GetStartTime()) + require.Less(t, timeSpent, sp.GetRoundTimeHandler().TimeDuration()) } func TestScheduledProcessorWrapper_IsProcessedInProgressEarlyCompletionWithError(t *testing.T) { t.Parallel() args 
:= defaultScheduledProcessorWrapperArgs() - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, err) - sp.setStatus(inProgress) - sp.startTime = time.Now() + sp.SetStatus(spos.InProgress) + sp.SetStartTime(time.Now()) go func() { time.Sleep(10 * time.Millisecond) - sp.setStatus(processingError) + sp.SetStatus(spos.ProcessingError) }() require.False(t, sp.IsProcessedOKWithTimeout()) endTime := time.Now() - timeSpent := endTime.Sub(sp.startTime) - require.Less(t, timeSpent, sp.roundTimeDurationHandler.TimeDuration()) + timeSpent := endTime.Sub(sp.GetStartTime()) + require.Less(t, timeSpent, sp.GetRoundTimeHandler().TimeDuration()) } func TestScheduledProcessorWrapper_IsProcessedInProgressAlreadyStartedNoCompletion(t *testing.T) { t.Parallel() args := defaultScheduledProcessorWrapperArgs() - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, err) - sp.setStatus(inProgress) + sp.SetStatus(spos.InProgress) startTime := time.Now() - sp.startTime = startTime.Add(-10 * time.Millisecond) + sp.SetStartTime(startTime.Add(-10 * time.Millisecond)) require.False(t, sp.IsProcessedOKWithTimeout()) endTime := time.Now() - require.Less(t, endTime.Sub(startTime), sp.roundTimeDurationHandler.TimeDuration()) - require.Greater(t, endTime.Sub(startTime), sp.roundTimeDurationHandler.TimeDuration()-10*time.Millisecond) + require.Less(t, endTime.Sub(startTime), sp.GetRoundTimeHandler().TimeDuration()) + require.Greater(t, endTime.Sub(startTime), sp.GetRoundTimeHandler().TimeDuration()-10*time.Millisecond) } func TestScheduledProcessorWrapper_IsProcessedInProgressTimeout(t *testing.T) { t.Parallel() args := defaultScheduledProcessorWrapperArgs() - sp, err := NewScheduledProcessorWrapper(args) + sp, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, err) - sp.setStatus(inProgress) - sp.startTime = time.Now() + sp.SetStatus(spos.InProgress) + sp.SetStartTime(time.Now()) require.False(t, sp.IsProcessedOKWithTimeout()) endTime := time.Now() - require.Greater(t, endTime.Sub(sp.startTime), sp.roundTimeDurationHandler.TimeDuration()) + require.Greater(t, endTime.Sub(sp.GetStartTime()), sp.GetRoundTimeHandler().TimeDuration()) } func TestScheduledProcessorWrapper_StatusGetterAndSetter(t *testing.T) { t.Parallel() - args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{}, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } - sp, _ := NewScheduledProcessorWrapper(args) - require.Equal(t, processingNotStarted, sp.getStatus()) + sp, _ := spos.NewScheduledProcessorWrapper(args) + require.Equal(t, spos.ProcessingNotStarted, sp.GetStatus()) - sp.setStatus(processingOK) - require.Equal(t, processingOK, sp.getStatus()) + sp.SetStatus(spos.ProcessingOK) + require.Equal(t, spos.ProcessingOK, sp.GetStatus()) - sp.setStatus(inProgress) - require.Equal(t, inProgress, sp.getStatus()) + sp.SetStatus(spos.InProgress) + require.Equal(t, spos.InProgress, sp.GetStatus()) - sp.setStatus(processingError) - require.Equal(t, processingError, sp.getStatus()) + sp.SetStatus(spos.ProcessingError) + require.Equal(t, spos.ProcessingError, sp.GetStatus()) } func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV1ProcessingOK(t *testing.T) { t.Parallel() processScheduledCalled := atomic.Flag{} - args 
:= ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{ ProcessScheduledBlockCalled: func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { processScheduledCalled.SetValue(true) return nil }, }, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } - sp, _ := NewScheduledProcessorWrapper(args) - require.Equal(t, processingNotStarted, sp.getStatus()) + sp, _ := spos.NewScheduledProcessorWrapper(args) + require.Equal(t, spos.ProcessingNotStarted, sp.GetStatus()) header := &block.Header{} body := &block.Body{} sp.StartScheduledProcessing(header, body, time.Now()) time.Sleep(10 * time.Millisecond) require.False(t, processScheduledCalled.IsSet()) - require.Equal(t, processingOK, sp.getStatus()) + require.Equal(t, spos.ProcessingOK, sp.GetStatus()) } func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ProcessingWithError(t *testing.T) { t.Parallel() processScheduledCalled := atomic.Flag{} - args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{ ProcessScheduledBlockCalled: func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { processScheduledCalled.SetValue(true) return errors.New("processing error") }, }, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } - sp, _ := NewScheduledProcessorWrapper(args) - require.Equal(t, processingNotStarted, sp.getStatus()) + sp, _ := spos.NewScheduledProcessorWrapper(args) + require.Equal(t, spos.ProcessingNotStarted, sp.GetStatus()) header := &block.HeaderV2{} body := &block.Body{} sp.StartScheduledProcessing(header, body, time.Now()) - require.Equal(t, inProgress, sp.getStatus()) + require.Equal(t, spos.InProgress, sp.GetStatus()) time.Sleep(100 * time.Millisecond) require.True(t, processScheduledCalled.IsSet()) - require.Equal(t, processingError, sp.getStatus()) + require.Equal(t, spos.ProcessingError, sp.GetStatus()) } func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ProcessingOK(t *testing.T) { t.Parallel() processScheduledCalled := atomic.Flag{} - args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{}, + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{}, Processor: &testscommon.BlockProcessorStub{ ProcessScheduledBlockCalled: func(header data.HeaderHandler, body data.BodyHandler, haveTime func() time.Duration) error { processScheduledCalled.SetValue(true) return nil }, }, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } - sp, _ := NewScheduledProcessorWrapper(args) - require.Equal(t, processingNotStarted, sp.getStatus()) + sp, _ := spos.NewScheduledProcessorWrapper(args) + require.Equal(t, spos.ProcessingNotStarted, sp.GetStatus()) header := &block.HeaderV2{} body := &block.Body{} sp.StartScheduledProcessing(header, body, time.Now()) - require.Equal(t, inProgress, sp.getStatus()) + require.Equal(t, spos.InProgress, sp.GetStatus()) - time.Sleep(100 * time.Millisecond) + time.Sleep(200 * time.Millisecond) require.True(t, processScheduledCalled.IsSet()) - require.Equal(t, processingOK, sp.getStatus()) + 
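The tests above exercise how the wrapper reports completion relative to a start time and a round duration. A generic, runnable sketch of the same wait-for-status-with-timeout pattern, using hypothetical names and not the wrapper's actual implementation:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// waitForStatus polls get() until it returns want or the timeout expires.
// It mirrors the behaviour the IsProcessedOKWithTimeout tests check:
// completion observed before the deadline versus timing out.
func waitForStatus(get func() int32, want int32, timeout time.Duration) bool {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		if get() == want {
			return true
		}
		time.Sleep(5 * time.Millisecond)
	}
	return false
}

func main() {
	const (
		inProgress   int32 = 1
		processingOK int32 = 2
	)

	var status atomic.Int32
	status.Store(inProgress)

	// Flip the status to processingOK shortly after starting to wait.
	go func() {
		time.Sleep(20 * time.Millisecond)
		status.Store(processingOK)
	}()

	fmt.Println(waitForStatus(status.Load, processingOK, 200*time.Millisecond)) // true: completed early
	fmt.Println(waitForStatus(status.Load, 99, 50*time.Millisecond))            // false: times out
}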
require.Equal(t, spos.ProcessingOK, sp.GetStatus()) } func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopped(t *testing.T) { @@ -332,8 +335,8 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopped( processScheduledCalled := atomic.Flag{} - args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{ + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{ CurrentTimeCalled: func() time.Time { return time.Now() }, @@ -350,10 +353,10 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopped( } }, }, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } - spw, err := NewScheduledProcessorWrapper(args) + spw, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, err) hdr := &block.HeaderV2{} @@ -363,9 +366,9 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopped( startTime := time.Now() spw.ForceStopScheduledExecutionBlocking() endTime := time.Now() - status := spw.getStatus() + status := spw.GetStatus() require.True(t, processScheduledCalled.IsSet()) - require.Equal(t, stopped, status, status.String()) + require.Equal(t, spos.Stopped, status, status.String()) require.Less(t, 10*time.Millisecond, endTime.Sub(startTime)) } @@ -373,8 +376,8 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopAfte t.Parallel() processScheduledCalled := atomic.Flag{} - args := ScheduledProcessorWrapperArgs{ - SyncTimer: &mock.SyncTimerMock{ + args := spos.ScheduledProcessorWrapperArgs{ + SyncTimer: &consensus.SyncTimerMock{ CurrentTimeCalled: func() time.Time { return time.Now() }, @@ -386,10 +389,10 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopAfte return nil }, }, - RoundTimeDurationHandler: &mock.RoundHandlerMock{}, + RoundTimeDurationHandler: &consensus.RoundHandlerMock{}, } - spw, err := NewScheduledProcessorWrapper(args) + spw, err := spos.NewScheduledProcessorWrapper(args) require.Nil(t, err) hdr := &block.HeaderV2{} @@ -397,7 +400,7 @@ func TestScheduledProcessorWrapper_StartScheduledProcessingHeaderV2ForceStopAfte spw.StartScheduledProcessing(hdr, blkBody, time.Now()) time.Sleep(200 * time.Millisecond) spw.ForceStopScheduledExecutionBlocking() - status := spw.getStatus() + status := spw.GetStatus() require.True(t, processScheduledCalled.IsSet()) - require.Equal(t, processingOK, status, status.String()) + require.Equal(t, spos.ProcessingOK, status, status.String()) } diff --git a/consensus/spos/sposFactory/sposFactory.go b/consensus/spos/sposFactory/sposFactory.go index 84faafe53e6..99f0cf682eb 100644 --- a/consensus/spos/sposFactory/sposFactory.go +++ b/consensus/spos/sposFactory/sposFactory.go @@ -6,50 +6,15 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/broadcast" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" - "github.com/multiversx/mx-chain-go/outport" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" ) -// GetSubroundsFactory returns a subrounds factory depending on the 
given parameter -func GetSubroundsFactory( - consensusDataContainer spos.ConsensusCoreHandler, - consensusState *spos.ConsensusState, - worker spos.WorkerHandler, - consensusType string, - appStatusHandler core.AppStatusHandler, - outportHandler outport.OutportHandler, - sentSignatureTracker spos.SentSignaturesTracker, - chainID []byte, - currentPid core.PeerID, -) (spos.SubroundsFactory, error) { - switch consensusType { - case blsConsensusType: - subRoundFactoryBls, err := bls.NewSubroundsFactory( - consensusDataContainer, - consensusState, - worker, - chainID, - currentPid, - appStatusHandler, - sentSignatureTracker, - ) - if err != nil { - return nil, err - } - - subRoundFactoryBls.SetOutportHandler(outportHandler) - - return subRoundFactoryBls, nil - default: - return nil, ErrInvalidConsensusType - } -} - // GetConsensusCoreFactory returns a consensus service depending on the given parameter func GetConsensusCoreFactory(consensusType string) (spos.ConsensusService, error) { switch consensusType { @@ -77,6 +42,20 @@ func GetBroadcastMessenger( return nil, spos.ErrNilShardCoordinator } + dbbArgs := &broadcast.ArgsDelayedBlockBroadcaster{ + InterceptorsContainer: interceptorsContainer, + HeadersSubscriber: headersSubscriber, + ShardCoordinator: shardCoordinator, + LeaderCacheSize: maxDelayCacheSize, + ValidatorCacheSize: maxDelayCacheSize, + AlarmScheduler: alarmScheduler, + } + + delayedBroadcaster, err := broadcast.NewDelayedBlockBroadcaster(dbbArgs) + if err != nil { + return nil, err + } + commonMessengerArgs := broadcast.CommonMessengerArgs{ Marshalizer: marshalizer, Hasher: hasher, @@ -89,6 +68,7 @@ func GetBroadcastMessenger( InterceptorsContainer: interceptorsContainer, AlarmScheduler: alarmScheduler, KeysHandler: keysHandler, + DelayedBroadcaster: delayedBroadcaster, } if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { diff --git a/consensus/spos/sposFactory/sposFactory_test.go b/consensus/spos/sposFactory/sposFactory_test.go index 4a672a3343f..1f122884530 100644 --- a/consensus/spos/sposFactory/sposFactory_test.go +++ b/consensus/spos/sposFactory/sposFactory_test.go @@ -5,20 +5,18 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/sposFactory" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/multiversx/mx-chain-go/testscommon/outport" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/testscommon/pool" ) -var currentPid = core.PeerID("pid") - func TestGetConsensusCoreFactory_InvalidTypeShouldErr(t *testing.T) { t.Parallel() @@ -37,98 +35,6 @@ func TestGetConsensusCoreFactory_BlsShouldWork(t *testing.T) { assert.False(t, check.IfNil(csf)) } -func TestGetSubroundsFactory_BlsNilConsensusCoreShouldErr(t *testing.T) { - t.Parallel() - - worker := &mock.SposWorkerMock{} - consensusType := 
consensus.BlsConsensusType - statusHandler := statusHandlerMock.NewAppStatusHandlerMock() - chainID := []byte("chain-id") - indexer := &outport.OutportStub{} - sf, err := sposFactory.GetSubroundsFactory( - nil, - &spos.ConsensusState{}, - worker, - consensusType, - statusHandler, - indexer, - &testscommon.SentSignatureTrackerStub{}, - chainID, - currentPid, - ) - - assert.Nil(t, sf) - assert.Equal(t, spos.ErrNilConsensusCore, err) -} - -func TestGetSubroundsFactory_BlsNilStatusHandlerShouldErr(t *testing.T) { - t.Parallel() - - consensusCore := mock.InitConsensusCore() - worker := &mock.SposWorkerMock{} - consensusType := consensus.BlsConsensusType - chainID := []byte("chain-id") - indexer := &outport.OutportStub{} - sf, err := sposFactory.GetSubroundsFactory( - consensusCore, - &spos.ConsensusState{}, - worker, - consensusType, - nil, - indexer, - &testscommon.SentSignatureTrackerStub{}, - chainID, - currentPid, - ) - - assert.Nil(t, sf) - assert.Equal(t, spos.ErrNilAppStatusHandler, err) -} - -func TestGetSubroundsFactory_BlsShouldWork(t *testing.T) { - t.Parallel() - - consensusCore := mock.InitConsensusCore() - worker := &mock.SposWorkerMock{} - consensusType := consensus.BlsConsensusType - statusHandler := statusHandlerMock.NewAppStatusHandlerMock() - chainID := []byte("chain-id") - indexer := &outport.OutportStub{} - sf, err := sposFactory.GetSubroundsFactory( - consensusCore, - &spos.ConsensusState{}, - worker, - consensusType, - statusHandler, - indexer, - &testscommon.SentSignatureTrackerStub{}, - chainID, - currentPid, - ) - assert.Nil(t, err) - assert.False(t, check.IfNil(sf)) -} - -func TestGetSubroundsFactory_InvalidConsensusTypeShouldErr(t *testing.T) { - t.Parallel() - - consensusType := "invalid" - sf, err := sposFactory.GetSubroundsFactory( - nil, - nil, - nil, - consensusType, - nil, - nil, - nil, - nil, - currentPid, - ) - - assert.Nil(t, sf) - assert.Equal(t, sposFactory.ErrInvalidConsensusType, err) -} - func TestGetBroadcastMessenger_ShardShouldWork(t *testing.T) { t.Parallel() @@ -140,9 +46,9 @@ func TestGetBroadcastMessenger_ShardShouldWork(t *testing.T) { return 0 } peerSigHandler := &mock.PeerSignatureHandler{} - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptosContainer := &testscommon.InterceptorsContainerStub{} - alarmSchedulerStub := &mock.AlarmSchedulerStub{} + alarmSchedulerStub := &testscommon.AlarmSchedulerStub{} bm, err := sposFactory.GetBroadcastMessenger( marshalizer, @@ -171,9 +77,9 @@ func TestGetBroadcastMessenger_MetachainShouldWork(t *testing.T) { return core.MetachainShardId } peerSigHandler := &mock.PeerSignatureHandler{} - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptosContainer := &testscommon.InterceptorsContainerStub{} - alarmSchedulerStub := &mock.AlarmSchedulerStub{} + alarmSchedulerStub := &testscommon.AlarmSchedulerStub{} bm, err := sposFactory.GetBroadcastMessenger( marshalizer, @@ -194,9 +100,9 @@ func TestGetBroadcastMessenger_MetachainShouldWork(t *testing.T) { func TestGetBroadcastMessenger_NilShardCoordinatorShouldErr(t *testing.T) { t.Parallel() - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptosContainer := &testscommon.InterceptorsContainerStub{} - alarmSchedulerStub := &mock.AlarmSchedulerStub{} + alarmSchedulerStub := &testscommon.AlarmSchedulerStub{} bm, err := sposFactory.GetBroadcastMessenger( nil, @@ -221,9 +127,9 @@ func 
TestGetBroadcastMessenger_InvalidShardIdShouldErr(t *testing.T) { shardCoord.SelfIDCalled = func() uint32 { return 37 } - headersSubscriber := &mock.HeadersCacherStub{} + headersSubscriber := &pool.HeadersPoolStub{} interceptosContainer := &testscommon.InterceptorsContainerStub{} - alarmSchedulerStub := &mock.AlarmSchedulerStub{} + alarmSchedulerStub := &testscommon.AlarmSchedulerStub{} bm, err := sposFactory.GetBroadcastMessenger( nil, diff --git a/consensus/spos/subround.go b/consensus/spos/subround.go index 1d1b07589a6..00b2c55fe6c 100644 --- a/consensus/spos/subround.go +++ b/consensus/spos/subround.go @@ -6,18 +6,24 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/consensus" ) var _ consensus.SubroundHandler = (*Subround)(nil) +const ( + singleKeyStartMsg = " (my turn)" + multiKeyStartMsg = " (my turn in multi-key)" +) + // Subround struct contains the needed data for one Subround and the Subround properties. It defines a Subround // with its properties (its ID, next Subround ID, its duration, its name) and also it has some handler functions // which should be set. Job function will be the main function of this Subround, Extend function will handle the overtime // situation of the Subround and Check function will decide if in this Subround the consensus is achieved type Subround struct { ConsensusCoreHandler - *ConsensusState + ConsensusStateHandler previous int current int @@ -45,7 +51,7 @@ func NewSubround( startTime int64, endTime int64, name string, - consensusState *ConsensusState, + consensusState ConsensusStateHandler, consensusStateChangedChannel chan bool, executeStoredMessages func(), container ConsensusCoreHandler, @@ -67,7 +73,7 @@ func NewSubround( sr := Subround{ ConsensusCoreHandler: container, - ConsensusState: consensusState, + ConsensusStateHandler: consensusState, previous: previous, current: current, next: next, @@ -88,7 +94,7 @@ func NewSubround( } func checkNewSubroundParams( - state *ConsensusState, + state ConsensusStateHandler, consensusStateChangedChannel chan bool, executeStoredMessages func(), container ConsensusCoreHandler, @@ -145,7 +151,7 @@ func (sr *Subround) DoWork(ctx context.Context, roundHandler consensus.RoundHand } case <-time.After(roundHandler.RemainingTime(startTime, maxTime)): if sr.Extend != nil { - sr.RoundCanceled = true + sr.SetRoundCanceled(true) sr.Extend(sr.current) } @@ -206,7 +212,7 @@ func (sr *Subround) ConsensusChannel() chan bool { // GetAssociatedPid returns the associated PeerID to the provided public key bytes func (sr *Subround) GetAssociatedPid(pkBytes []byte) core.PeerID { - return sr.keysHandler.GetAssociatedPid(pkBytes) + return sr.GetKeysHandler().GetAssociatedPid(pkBytes) } // ShouldConsiderSelfKeyInConsensus returns true if current machine is the main one, or it is a backup machine but the main @@ -221,6 +227,36 @@ func (sr *Subround) ShouldConsiderSelfKeyInConsensus() bool { return isMainMachineInactive } +// IsSelfInConsensusGroup returns true is the current node is in consensus group in single +// key or in multi-key mode +func (sr *Subround) IsSelfInConsensusGroup() bool { + return sr.IsNodeInConsensusGroup(sr.SelfPubKey()) || sr.IsMultiKeyInConsensusGroup() +} + +// IsSelfLeader returns true is the current node is leader is single key or in +// multi-key mode +func (sr *Subround) IsSelfLeader() bool { + return sr.IsSelfLeaderInCurrentRound() || 
sr.IsMultiKeyLeaderInCurrentRound() +} + +// IsSelfLeaderInCurrentRound method checks if the current node is leader in the current round +func (sr *Subround) IsSelfLeaderInCurrentRound() bool { + return sr.IsNodeLeaderInCurrentRound(sr.SelfPubKey()) && sr.ShouldConsiderSelfKeyInConsensus() +} + +// GetLeaderStartRoundMessage returns the leader start round message based on single key +// or multi-key node type +func (sr *Subround) GetLeaderStartRoundMessage() string { + if sr.IsMultiKeyLeaderInCurrentRound() { + return multiKeyStartMsg + } + if sr.IsSelfLeaderInCurrentRound() { + return singleKeyStartMsg + } + + return "" +} + // IsInterfaceNil returns true if there is no value under the interface func (sr *Subround) IsInterfaceNil() bool { return sr == nil diff --git a/consensus/spos/subround_test.go b/consensus/spos/subround_test.go index 202899e1a24..a07b3d460fd 100644 --- a/consensus/spos/subround_test.go +++ b/consensus/spos/subround_test.go @@ -1,19 +1,24 @@ package spos_test import ( + "bytes" "context" "sync" "testing" "time" "github.com/multiversx/mx-chain-core-go/core" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/consensus/spos" "github.com/multiversx/mx-chain-go/consensus/spos/bls" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) var chainID = []byte("chain ID") @@ -57,6 +62,7 @@ func initConsensusState() *spos.ConsensusState { ) rcns.SetConsensusGroup(eligibleList) + rcns.SetLeader(eligibleList[indexLeader]) rcns.ResetRoundState() pBFTThreshold := consensusGroupSize*2/3 + 1 @@ -84,14 +90,14 @@ func initConsensusState() *spos.ConsensusState { ) cns.Data = []byte("X") - cns.RoundIndex = 0 + cns.SetRoundIndex(0) return cns } func TestSubround_NewSubroundNilConsensusStateShouldFail(t *testing.T) { t.Parallel() - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() ch := make(chan bool, 1) sr, err := spos.NewSubround( @@ -118,7 +124,7 @@ func TestSubround_NewSubroundNilChannelShouldFail(t *testing.T) { t.Parallel() consensusState := initConsensusState() - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, err := spos.NewSubround( -1, @@ -144,7 +150,7 @@ func TestSubround_NewSubroundNilExecuteStoredMessagesShouldFail(t *testing.T) { t.Parallel() consensusState := initConsensusState() - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() ch := make(chan bool, 1) sr, err := spos.NewSubround( @@ -198,7 +204,7 @@ func TestSubround_NilContainerBlockchainShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetBlockchain(nil) sr, err := spos.NewSubround( @@ -226,7 +232,7 @@ func TestSubround_NilContainerBlockprocessorShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetBlockProcessor(nil) sr, err := 
spos.NewSubround( @@ -254,7 +260,7 @@ func TestSubround_NilContainerBootstrapperShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetBootStrapper(nil) sr, err := spos.NewSubround( @@ -282,7 +288,7 @@ func TestSubround_NilContainerChronologyShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetChronology(nil) sr, err := spos.NewSubround( @@ -310,7 +316,7 @@ func TestSubround_NilContainerHasherShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetHasher(nil) sr, err := spos.NewSubround( @@ -338,7 +344,7 @@ func TestSubround_NilContainerMarshalizerShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetMarshalizer(nil) sr, err := spos.NewSubround( @@ -366,7 +372,7 @@ func TestSubround_NilContainerMultiSignerShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetMultiSignerContainer(cryptoMocks.NewMultiSignerContainerMock(nil)) sr, err := spos.NewSubround( @@ -394,7 +400,7 @@ func TestSubround_NilContainerRoundHandlerShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetRoundHandler(nil) sr, err := spos.NewSubround( @@ -422,7 +428,7 @@ func TestSubround_NilContainerShardCoordinatorShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetShardCoordinator(nil) sr, err := spos.NewSubround( @@ -450,7 +456,7 @@ func TestSubround_NilContainerSyncTimerShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetSyncTimer(nil) sr, err := spos.NewSubround( @@ -478,8 +484,8 @@ func TestSubround_NilContainerValidatorGroupSelectorShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() - container.SetValidatorGroupSelector(nil) + container := consensus.InitConsensusCore() + container.SetNodesCoordinator(nil) sr, err := spos.NewSubround( -1, @@ -506,7 +512,7 @@ func TestSubround_EmptyChainIDShouldFail(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, err := spos.NewSubround( -1, bls.SrStartRound, @@ -532,7 +538,7 @@ func TestSubround_NewSubroundShouldWork(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, err := spos.NewSubround( -1, bls.SrStartRound, @@ -566,7 +572,7 @@ func TestSubround_DoWorkShouldReturnFalseWhenJobFunctionIsNotSet(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + 
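A small sketch of the start-round message selection added to subround.go above: multi-key leadership takes precedence, then single-key leadership, otherwise no suffix is appended. The trimmed-down interface and the fakeLeader type are hypothetical; the two message constants match the ones introduced by the patch.

package main

import "fmt"

const (
	singleKeyStartMsg = " (my turn)"
	multiKeyStartMsg  = " (my turn in multi-key)"
)

// leaderInfo is a hypothetical trimmed-down view of the Subround methods used
// by GetLeaderStartRoundMessage.
type leaderInfo interface {
	IsMultiKeyLeaderInCurrentRound() bool
	IsSelfLeaderInCurrentRound() bool
}

// startRoundSuffix mirrors the selection logic of GetLeaderStartRoundMessage.
func startRoundSuffix(info leaderInfo) string {
	if info.IsMultiKeyLeaderInCurrentRound() {
		return multiKeyStartMsg
	}
	if info.IsSelfLeaderInCurrentRound() {
		return singleKeyStartMsg
	}
	return ""
}

// fakeLeader is a toy implementation that makes the sketch runnable.
type fakeLeader struct{ multi, single bool }

func (f fakeLeader) IsMultiKeyLeaderInCurrentRound() bool { return f.multi }
func (f fakeLeader) IsSelfLeaderInCurrentRound() bool     { return f.single }

func main() {
	fmt.Printf("%q\n", startRoundSuffix(fakeLeader{multi: true}))  // " (my turn in multi-key)"
	fmt.Printf("%q\n", startRoundSuffix(fakeLeader{single: true})) // " (my turn)"
	fmt.Printf("%q\n", startRoundSuffix(fakeLeader{}))             // ""
}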
container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( -1, @@ -589,7 +595,7 @@ func TestSubround_DoWorkShouldReturnFalseWhenJobFunctionIsNotSet(t *testing.T) { } maxTime := time.Now().Add(100 * time.Millisecond) - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensus.RoundHandlerMock{} roundHandlerMock.RemainingTimeCalled = func(time.Time, time.Duration) time.Duration { return time.Until(maxTime) } @@ -604,7 +610,7 @@ func TestSubround_DoWorkShouldReturnFalseWhenCheckFunctionIsNotSet(t *testing.T) consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( -1, @@ -627,7 +633,7 @@ func TestSubround_DoWorkShouldReturnFalseWhenCheckFunctionIsNotSet(t *testing.T) sr.Check = nil maxTime := time.Now().Add(100 * time.Millisecond) - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensus.RoundHandlerMock{} roundHandlerMock.RemainingTimeCalled = func(time.Time, time.Duration) time.Duration { return time.Until(maxTime) } @@ -651,7 +657,7 @@ func TestSubround_DoWorkShouldReturnTrueWhenJobAndConsensusAreDone(t *testing.T) func testDoWork(t *testing.T, checkDone bool, shouldWork bool) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( -1, @@ -676,7 +682,7 @@ func testDoWork(t *testing.T, checkDone bool, shouldWork bool) { } maxTime := time.Now().Add(100 * time.Millisecond) - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensus.RoundHandlerMock{} roundHandlerMock.RemainingTimeCalled = func(time.Time, time.Duration) time.Duration { return time.Until(maxTime) } @@ -690,7 +696,7 @@ func TestSubround_DoWorkShouldReturnTrueWhenJobIsDoneAndConsensusIsDoneAfterAWhi consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( -1, @@ -723,7 +729,7 @@ func TestSubround_DoWorkShouldReturnTrueWhenJobIsDoneAndConsensusIsDoneAfterAWhi } maxTime := time.Now().Add(2000 * time.Millisecond) - roundHandlerMock := &mock.RoundHandlerMock{} + roundHandlerMock := &consensus.RoundHandlerMock{} roundHandlerMock.RemainingTimeCalled = func(time.Time, time.Duration) time.Duration { return time.Until(maxTime) } @@ -748,7 +754,7 @@ func TestSubround_Previous(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( bls.SrStartRound, @@ -780,7 +786,7 @@ func TestSubround_Current(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( bls.SrStartRound, @@ -812,7 +818,7 @@ func TestSubround_Next(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( bls.SrStartRound, @@ -844,7 +850,7 @@ func TestSubround_StartTime(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetRoundHandler(initRoundHandlerMock()) sr, _ := spos.NewSubround( bls.SrBlock, @@ -876,7 +882,7 @@ func 
TestSubround_EndTime(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() container.SetRoundHandler(initRoundHandlerMock()) sr, _ := spos.NewSubround( bls.SrStartRound, @@ -908,7 +914,7 @@ func TestSubround_Name(t *testing.T) { consensusState := initConsensusState() ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() sr, _ := spos.NewSubround( bls.SrStartRound, @@ -941,7 +947,7 @@ func TestSubround_GetAssociatedPid(t *testing.T) { keysHandler := &testscommon.KeysHandlerStub{} consensusState := internalInitConsensusStateWithKeysHandler(keysHandler) ch := make(chan bool, 1) - container := mock.InitConsensusCore() + container := consensus.InitConsensusCore() subround, _ := spos.NewSubround( bls.SrStartRound, @@ -971,3 +977,370 @@ func TestSubround_GetAssociatedPid(t *testing.T) { assert.Equal(t, pid, subround.GetAssociatedPid(providedPkBytes)) assert.True(t, wasCalled) } + +func TestSubround_ShouldConsiderSelfKeyInConsensus(t *testing.T) { + t.Parallel() + + t.Run("is main machine active, should return true", func(t *testing.T) { + t.Parallel() + + consensusState := initConsensusState() + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + redundancyHandler := &mock.NodeRedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool { + return false + }, + IsMainMachineActiveCalled: func() bool { + return true + }, + } + container.SetNodeRedundancyHandler(redundancyHandler) + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + require.True(t, sr.ShouldConsiderSelfKeyInConsensus()) + }) + + t.Run("is redundancy node machine active, should return true", func(t *testing.T) { + t.Parallel() + + consensusState := initConsensusState() + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + redundancyHandler := &mock.NodeRedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool { + return true + }, + IsMainMachineActiveCalled: func() bool { + return false + }, + } + container.SetNodeRedundancyHandler(redundancyHandler) + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + require.True(t, sr.ShouldConsiderSelfKeyInConsensus()) + }) + + t.Run("is redundancy node machine but inactive, should return false", func(t *testing.T) { + t.Parallel() + + consensusState := initConsensusState() + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + redundancyHandler := &mock.NodeRedundancyHandlerStub{ + IsRedundancyNodeCalled: func() bool { + return true + }, + IsMainMachineActiveCalled: func() bool { + return true + }, + } + container.SetNodeRedundancyHandler(redundancyHandler) + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + require.False(t, 
sr.ShouldConsiderSelfKeyInConsensus()) + }) +} + +func TestSubround_GetLeaderStartRoundMessage(t *testing.T) { + t.Parallel() + + t.Run("should work with multi key node", func(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal([]byte("1"), pkBytes) + }, + } + consensusState := internalInitConsensusStateWithKeysHandler(keysHandler) + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetSelfPubKey("1") + + require.Equal(t, spos.LeaderMultiKeyStartMsg, sr.GetLeaderStartRoundMessage()) + }) + + t.Run("should work with single key node", func(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal([]byte("2"), pkBytes) + }, + } + consensusState := internalInitConsensusStateWithKeysHandler(keysHandler) + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetSelfPubKey("1") + + require.Equal(t, spos.LeaderSingleKeyStartMsg, sr.GetLeaderStartRoundMessage()) + }) + + t.Run("should return empty string when leader is not managed by current node", func(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return false + }, + } + consensusState := internalInitConsensusStateWithKeysHandler(keysHandler) + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetSelfPubKey("5") + + require.Equal(t, "", sr.GetLeaderStartRoundMessage()) + }) +} + +func TestSubround_IsSelfInConsensusGroup(t *testing.T) { + t.Parallel() + + t.Run("should work with multi key node", func(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal([]byte("1"), pkBytes) + }, + } + consensusState := internalInitConsensusStateWithKeysHandler(keysHandler) + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + require.True(t, sr.IsSelfInConsensusGroup()) + }) + + t.Run("should work with single key node", func(t *testing.T) { + t.Parallel() + + consensusState := internalInitConsensusStateWithKeysHandler(&testscommon.KeysHandlerStub{}) + ch := make(chan bool, 1) + 
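The tests above simulate multi-key mode by scripting the keys handler stub's IsKeyManagedByCurrentNodeCalled field. A generic, self-contained illustration of that function-field stub idiom, with hypothetical names rather than the repository's actual KeysHandlerStub:

package main

import (
	"bytes"
	"fmt"
)

// keysHandlerStub follows the same style as the stubs used in the tests:
// each method delegates to an optional function field so a test can script
// exactly the behaviour it needs.
type keysHandlerStub struct {
	IsKeyManagedByCurrentNodeCalled func(pkBytes []byte) bool
}

func (stub *keysHandlerStub) IsKeyManagedByCurrentNode(pkBytes []byte) bool {
	if stub.IsKeyManagedByCurrentNodeCalled != nil {
		return stub.IsKeyManagedByCurrentNodeCalled(pkBytes)
	}
	return false // zero-value stub behaves as "no managed keys"
}

func main() {
	// Simulate a multi-key node that manages only the key "1".
	stub := &keysHandlerStub{
		IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool {
			return bytes.Equal(pkBytes, []byte("1"))
		},
	}

	fmt.Println(stub.IsKeyManagedByCurrentNode([]byte("1"))) // true
	fmt.Println(stub.IsKeyManagedByCurrentNode([]byte("2"))) // false
}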
container := consensus.InitConsensusCore() + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetSelfPubKey("1") + + require.True(t, sr.IsSelfInConsensusGroup()) + }) +} + +func TestSubround_IsSelfLeader(t *testing.T) { + t.Parallel() + + t.Run("should work with multi key node", func(t *testing.T) { + t.Parallel() + + keysHandler := &testscommon.KeysHandlerStub{ + IsKeyManagedByCurrentNodeCalled: func(pkBytes []byte) bool { + return bytes.Equal([]byte("1"), pkBytes) + }, + } + consensusState := internalInitConsensusStateWithKeysHandler(keysHandler) + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + + sr.SetLeader("1") + + require.True(t, sr.IsSelfLeader()) + }) + + t.Run("should work with single key node", func(t *testing.T) { + t.Parallel() + + consensusState := internalInitConsensusStateWithKeysHandler(&testscommon.KeysHandlerStub{}) + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + sr, _ := spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + sr.SetSelfPubKey("1") + sr.SetLeader("1") + + require.True(t, sr.IsSelfLeader()) + }) +} + +func TestSubround_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var sr *spos.Subround + require.True(t, sr.IsInterfaceNil()) + + consensusState := internalInitConsensusStateWithKeysHandler(&testscommon.KeysHandlerStub{}) + ch := make(chan bool, 1) + container := consensus.InitConsensusCore() + + sr, _ = spos.NewSubround( + bls.SrStartRound, + bls.SrBlock, + bls.SrSignature, + int64(5*roundTimeDuration/100), + int64(25*roundTimeDuration/100), + "(BLOCK)", + consensusState, + ch, + executeStoredMessages, + container, + chainID, + currentPid, + &statusHandler.AppStatusHandlerStub{}, + ) + require.False(t, sr.IsInterfaceNil()) +} diff --git a/consensus/spos/worker.go b/consensus/spos/worker.go index f11e40d3089..54c493fd7ed 100644 --- a/consensus/spos/worker.go +++ b/consensus/spos/worker.go @@ -35,6 +35,10 @@ var _ closing.Closer = (*Worker)(nil) const sleepTime = 5 * time.Millisecond const redundancySingleKeySteppedIn = "single-key node stepped in" +type blockProcessorWithPool interface { + RemoveHeaderFromPool(headerHash []byte) +} + // Worker defines the data needed by spos to communicate between nodes which are in the validators group type Worker struct { consensusService ConsensusService @@ -54,11 +58,12 @@ type Worker struct { headerSigVerifier HeaderSigVerifier headerIntegrityVerifier process.HeaderIntegrityVerifier appStatusHandler core.AppStatusHandler + enableEpochsHandler common.EnableEpochsHandler networkShardingCollector consensus.NetworkShardingCollector receivedMessages map[consensus.MessageType][]*consensus.Message - receivedMessagesCalls map[consensus.MessageType]func(ctx context.Context, msg *consensus.Message) bool + 
receivedMessagesCalls map[consensus.MessageType][]func(ctx context.Context, msg *consensus.Message) bool executeMessageChannel chan *consensus.Message consensusStateChangedChannel chan bool @@ -72,6 +77,9 @@ type Worker struct { receivedHeadersHandlers []func(headerHandler data.HeaderHandler) mutReceivedHeadersHandler sync.RWMutex + receivedProofHandlers []func(proofHandler consensus.ProofHandler) + mutReceivedProofHandler sync.RWMutex + antifloodHandler consensus.P2PAntifloodHandler poolAdder PoolAdder @@ -80,6 +88,8 @@ type Worker struct { nodeRedundancyHandler consensus.NodeRedundancyHandler peerBlacklistHandler consensus.PeerBlacklistHandler closer core.SafeCloser + + invalidSignersCache InvalidSignersCache } // WorkerArgs holds the consensus worker arguments @@ -109,6 +119,8 @@ type WorkerArgs struct { AppStatusHandler core.AppStatusHandler NodeRedundancyHandler consensus.NodeRedundancyHandler PeerBlacklistHandler consensus.PeerBlacklistHandler + EnableEpochsHandler common.EnableEpochsHandler + InvalidSignersCache InvalidSignersCache } // NewWorker creates a new Worker object @@ -122,6 +134,9 @@ func NewWorker(args *WorkerArgs) (*Worker, error) { ConsensusState: args.ConsensusState, ConsensusService: args.ConsensusService, PeerSignatureHandler: args.PeerSignatureHandler, + EnableEpochsHandler: args.EnableEpochsHandler, + Marshaller: args.Marshalizer, + ShardCoordinator: args.ShardCoordinator, SignatureSize: args.SignatureSize, PublicKeySize: args.PublicKeySize, HeaderHashSize: args.Hasher.Size(), @@ -157,11 +172,13 @@ func NewWorker(args *WorkerArgs) (*Worker, error) { nodeRedundancyHandler: args.NodeRedundancyHandler, peerBlacklistHandler: args.PeerBlacklistHandler, closer: closing.NewSafeChanCloser(), + enableEpochsHandler: args.EnableEpochsHandler, + invalidSignersCache: args.InvalidSignersCache, } wrk.consensusMessageValidator = consensusMessageValidatorObj wrk.executeMessageChannel = make(chan *consensus.Message) - wrk.receivedMessagesCalls = make(map[consensus.MessageType]func(context.Context, *consensus.Message) bool) + wrk.receivedMessagesCalls = make(map[consensus.MessageType][]func(context.Context, *consensus.Message) bool) wrk.receivedHeadersHandlers = make([]func(data.HeaderHandler), 0) wrk.consensusStateChangedChannel = make(chan bool, 1) wrk.bootstrapper.AddSyncStateListener(wrk.receivedSyncState) @@ -257,6 +274,12 @@ func checkNewWorkerParams(args *WorkerArgs) error { if check.IfNil(args.PeerBlacklistHandler) { return ErrNilPeerBlacklistHandler } + if check.IfNil(args.EnableEpochsHandler) { + return ErrNilEnableEpochsHandler + } + if check.IfNil(args.InvalidSignersCache) { + return ErrNilInvalidSignersCache + } return nil } @@ -270,12 +293,84 @@ func (wrk *Worker) receivedSyncState(isNodeSynchronized bool) { } } +func (wrk *Worker) addFutureHeaderToProcessIfNeeded(header data.HeaderHandler) { + if check.IfNil(header) { + return + } + if !wrk.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + return + } + + isHeaderForNextRound := int64(header.GetRound()) == wrk.roundHandler.Index()+1 + if !isHeaderForNextRound { + return + } + + headerConsensusMessage, err := wrk.convertHeaderToConsensusMessage(header) + if err != nil { + log.Error("addFutureHeaderToProcessIfNeeded: convertHeaderToConsensusMessage failed", "error", err.Error()) + return + } + + go wrk.executeReceivedMessages(headerConsensusMessage) +} + +func (wrk *Worker) processReceivedHeaderMetricIfNeeded(header data.HeaderHandler) { + if check.IfNil(header) { + return + } + if 
!wrk.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + return + } + isHeaderForCurrentRound := int64(header.GetRound()) == wrk.roundHandler.Index() + if !isHeaderForCurrentRound { + return + } + isHeaderFromCurrentShard := header.GetShardID() == wrk.shardCoordinator.SelfId() + if !isHeaderFromCurrentShard { + return + } + + wrk.processReceivedHeaderMetric() +} + +func (wrk *Worker) convertHeaderToConsensusMessage(header data.HeaderHandler) (*consensus.Message, error) { + headerBytes, err := wrk.marshalizer.Marshal(header) + if err != nil { + return nil, ErrInvalidHeader + } + + return &consensus.Message{ + Header: headerBytes, + MsgType: int64(wrk.consensusService.GetMessageTypeBlockHeader()), + RoundIndex: int64(header.GetRound()), + }, nil +} + // ReceivedHeader process the received header, calling each received header handler registered in worker instance func (wrk *Worker) ReceivedHeader(headerHandler data.HeaderHandler, _ []byte) { + if check.IfNil(headerHandler) { + log.Trace("ReceivedHeader: nil header handler") + return + } + isHeaderForOtherShard := headerHandler.GetShardID() != wrk.shardCoordinator.SelfId() + if isHeaderForOtherShard { + log.Trace("ReceivedHeader: received header for other shard", + "self shardID", wrk.shardCoordinator.SelfId(), + "received shardID", headerHandler.GetShardID(), + ) + return + } + + wrk.addFutureHeaderToProcessIfNeeded(headerHandler) + wrk.processReceivedHeaderMetricIfNeeded(headerHandler) isHeaderForOtherRound := int64(headerHandler.GetRound()) != wrk.roundHandler.Index() - headerCanNotBeProcessed := isHeaderForOtherShard || isHeaderForOtherRound - if headerCanNotBeProcessed { + if isHeaderForOtherRound { + log.Trace("ReceivedHeader: received header for other round", + "self round", wrk.roundHandler.Index(), + "received round", headerHandler.GetRound(), + ) return } @@ -298,23 +393,53 @@ func (wrk *Worker) AddReceivedHeaderHandler(handler func(data.HeaderHandler)) { wrk.mutReceivedHeadersHandler.Unlock() } +// RemoveAllReceivedHeaderHandlers removes all the functions handlers +func (wrk *Worker) RemoveAllReceivedHeaderHandlers() { + wrk.mutReceivedHeadersHandler.Lock() + wrk.receivedHeadersHandlers = make([]func(data.HeaderHandler), 0) + wrk.mutReceivedHeadersHandler.Unlock() +} + +// ReceivedProof process the received proof, calling each received proof handler registered in worker instance +func (wrk *Worker) ReceivedProof(proofHandler consensus.ProofHandler) { + if check.IfNil(proofHandler) { + log.Trace("ReceivedProof: nil proof handler") + return + } + + log.Trace("ReceivedProof:", "proof header", proofHandler.GetHeaderHash()) + + wrk.mutReceivedProofHandler.RLock() + for _, handler := range wrk.receivedProofHandlers { + handler(proofHandler) + } + wrk.mutReceivedProofHandler.RUnlock() +} + +// AddReceivedProofHandler adds a new handler function for a received proof +func (wrk *Worker) AddReceivedProofHandler(handler func(proofHandler consensus.ProofHandler)) { + wrk.mutReceivedProofHandler.Lock() + wrk.receivedProofHandlers = append(wrk.receivedProofHandlers, handler) + wrk.mutReceivedProofHandler.Unlock() +} + func (wrk *Worker) initReceivedMessages() { wrk.mutReceivedMessages.Lock() wrk.receivedMessages = wrk.consensusService.InitReceivedMessages() wrk.mutReceivedMessages.Unlock() } -// AddReceivedMessageCall adds a new handler function for a received messege type +// AddReceivedMessageCall adds a new handler function for a received message type func (wrk *Worker) AddReceivedMessageCall(messageType 
consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) { wrk.mutReceivedMessagesCalls.Lock() - wrk.receivedMessagesCalls[messageType] = receivedMessageCall + wrk.receivedMessagesCalls[messageType] = append(wrk.receivedMessagesCalls[messageType], receivedMessageCall) wrk.mutReceivedMessagesCalls.Unlock() } // RemoveAllReceivedMessagesCalls removes all the functions handlers func (wrk *Worker) RemoveAllReceivedMessagesCalls() { wrk.mutReceivedMessagesCalls.Lock() - wrk.receivedMessagesCalls = make(map[consensus.MessageType]func(context.Context, *consensus.Message) bool) + wrk.receivedMessagesCalls = make(map[consensus.MessageType][]func(context.Context, *consensus.Message) bool) wrk.mutReceivedMessagesCalls.Unlock() } @@ -337,15 +462,15 @@ func (wrk *Worker) getCleanedList(cnsDataList []*consensus.Message) []*consensus } // ProcessReceivedMessage method redirects the received message to the channel which should handle it -func (wrk *Worker) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) error { +func (wrk *Worker) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) ([]byte, error) { if check.IfNil(message) { - return ErrNilMessage + return nil, ErrNilMessage } if message.Data() == nil { - return ErrNilDataToProcess + return nil, ErrNilDataToProcess } if len(message.Signature()) == 0 { - return ErrNilSignatureOnP2PMessage + return nil, ErrNilSignatureOnP2PMessage } isPeerBlacklisted := wrk.peerBlacklistHandler.IsPeerBlacklisted(fromConnectedPeer) @@ -353,13 +478,13 @@ func (wrk *Worker) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedP log.Debug("received message from blacklisted peer", "peer", fromConnectedPeer.Pretty(), ) - return ErrBlacklistedConsensusPeer + return nil, ErrBlacklistedConsensusPeer } topic := GetConsensusTopicID(wrk.shardCoordinator) err := wrk.antifloodHandler.CanProcessMessagesOnTopic(message.Peer(), topic, 1, uint64(len(message.Data())), message.SeqNo()) if err != nil { - return err + return nil, err } defer func() { @@ -376,7 +501,7 @@ func (wrk *Worker) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedP cnsMsg := &consensus.Message{} err = wrk.marshalizer.Unmarshal(cnsMsg, message.Data()) if err != nil { - return err + return nil, err } wrk.consensusState.ResetRoundsWithoutReceivedMessages(cnsMsg.GetPubKey(), message.Peer()) @@ -389,26 +514,18 @@ func (wrk *Worker) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedP ) } - msgType := consensus.MessageType(cnsMsg.MsgType) - - log.Trace("received message from consensus topic", - "msg type", wrk.consensusService.GetStringValue(msgType), - "from", cnsMsg.PubKey, - "header hash", cnsMsg.BlockHeaderHash, - "round", cnsMsg.RoundIndex, - "size", len(message.Data()), - ) - - err = wrk.consensusMessageValidator.checkConsensusMessageValidity(cnsMsg, message.Peer()) + err = wrk.checkValidityAndProcessFinalInfo(cnsMsg, message) if err != nil { - return err + return nil, err } wrk.networkShardingCollector.UpdatePeerIDInfo(message.Peer(), cnsMsg.PubKey, wrk.shardCoordinator.SelfId()) + msgType := consensus.MessageType(cnsMsg.MsgType) isMessageWithBlockBody := wrk.consensusService.IsMessageWithBlockBody(msgType) isMessageWithBlockHeader := wrk.consensusService.IsMessageWithBlockHeader(msgType) isMessageWithBlockBodyAndHeader := wrk.consensusService.IsMessageWithBlockBodyAndHeader(msgType) + isMessageWithInvalidSigners := 
wrk.consensusService.IsMessageWithInvalidSigners(msgType) if isMessageWithBlockBody || isMessageWithBlockBodyAndHeader { wrk.doJobOnMessageWithBlockBody(cnsMsg) @@ -417,7 +534,7 @@ func (wrk *Worker) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedP if isMessageWithBlockHeader || isMessageWithBlockBodyAndHeader { err = wrk.doJobOnMessageWithHeader(cnsMsg) if err != nil { - return err + return nil, err } } @@ -425,17 +542,24 @@ func (wrk *Worker) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedP wrk.doJobOnMessageWithSignature(cnsMsg, message) } + if isMessageWithInvalidSigners { + err = wrk.verifyMessageWithInvalidSigners(cnsMsg) + if err != nil { + return nil, err + } + } + errNotCritical := wrk.checkSelfState(cnsMsg) if errNotCritical != nil { log.Trace("checkSelfState", "error", errNotCritical.Error()) // in this case should return nil but do not process the message // nil error will mean that the interceptor will validate this message and broadcast it to the connected peers - return nil + return []byte{}, nil } go wrk.executeReceivedMessages(cnsMsg) - return nil + return []byte{}, nil } func (wrk *Worker) shouldBlacklistPeer(err error) bool { @@ -446,7 +570,9 @@ func (wrk *Worker) shouldBlacklistPeer(err error) bool { errors.Is(err, errorsErd.ErrPIDMismatch) || errors.Is(err, errorsErd.ErrSignatureMismatch) || errors.Is(err, nodesCoordinator.ErrEpochNodesConfigDoesNotExist) || - errors.Is(err, ErrMessageTypeLimitReached) { + errors.Is(err, ErrMessageTypeLimitReached) || + errors.Is(err, ErrEquivalentMessageAlreadyReceived) || + errors.Is(err, ErrInvalidSignersAlreadyReceived) { return false } @@ -503,7 +629,7 @@ func (wrk *Worker) doJobOnMessageWithHeader(cnsMsg *consensus.Message) error { err) } - wrk.processReceivedHeaderMetric(cnsMsg) + wrk.processReceivedHeaderMetricForConsensusMessage(cnsMsg) errNotCritical := wrk.forkDetector.AddHeader(header, headerHash, process.BHProposed, nil, nil) if errNotCritical != nil { @@ -516,6 +642,16 @@ func (wrk *Worker) doJobOnMessageWithHeader(cnsMsg *consensus.Message) error { return nil } +func (wrk *Worker) verifyMessageWithInvalidSigners(cnsMsg *consensus.Message) error { + // No need to guard this method by verification of common.AndromedaFlag as invalidSignersCache will have entries only for consensus v2 + if wrk.invalidSignersCache.CheckKnownInvalidSigners(cnsMsg.BlockHeaderHash, cnsMsg.InvalidSigners) { + // return error here to avoid further broadcast of this message + return ErrInvalidSignersAlreadyReceived + } + + return nil +} + func (wrk *Worker) verifyHeaderHash(hash []byte, marshalledHeader []byte) bool { computedHash := wrk.hasher.Compute(string(marshalledHeader)) return bytes.Equal(hash, computedHash) @@ -529,6 +665,11 @@ func (wrk *Worker) doJobOnMessageWithSignature(cnsMsg *consensus.Message, p2pMsg wrk.mapDisplayHashConsensusMessage[hash] = append(wrk.mapDisplayHashConsensusMessage[hash], cnsMsg) wrk.consensusState.AddMessageWithSignature(string(cnsMsg.PubKey), p2pMsg) + + log.Trace("received message with signature", + "from", core.GetTrimmedPk(hex.EncodeToString(cnsMsg.PubKey)), + "header hash", cnsMsg.BlockHeaderHash, + ) } func (wrk *Worker) addBlockToPool(bodyBytes []byte) { @@ -547,8 +688,16 @@ func (wrk *Worker) addBlockToPool(bodyBytes []byte) { } } -func (wrk *Worker) processReceivedHeaderMetric(cnsDta *consensus.Message) { - if wrk.consensusState.ConsensusGroup() == nil || !wrk.consensusState.IsNodeLeaderInCurrentRound(string(cnsDta.PubKey)) { +func (wrk *Worker) 
processReceivedHeaderMetricForConsensusMessage(cnsDta *consensus.Message) { + if !wrk.consensusState.IsNodeLeaderInCurrentRound(string(cnsDta.PubKey)) { + return + } + + wrk.processReceivedHeaderMetric() +} + +func (wrk *Worker) processReceivedHeaderMetric() { + if wrk.consensusState.ConsensusGroup() == nil { return } @@ -580,7 +729,7 @@ func (wrk *Worker) checkSelfState(cnsDta *consensus.Message) error { return ErrMessageFromItself } - if wrk.consensusState.RoundCanceled && wrk.consensusState.RoundIndex == cnsDta.RoundIndex { + if wrk.consensusState.GetRoundCanceled() && wrk.consensusState.GetRoundIndex() == cnsDta.RoundIndex { return ErrRoundCanceled } @@ -616,7 +765,7 @@ func (wrk *Worker) executeMessage(cnsDtaList []*consensus.Message) { if cnsDta == nil { continue } - if wrk.consensusState.RoundIndex != cnsDta.RoundIndex { + if wrk.consensusState.GetRoundIndex() != cnsDta.RoundIndex { continue } @@ -652,20 +801,47 @@ func (wrk *Worker) checkChannels(ctx context.Context) { msgType := consensus.MessageType(rcvDta.MsgType) - if callReceivedMessage, exist := wrk.receivedMessagesCalls[msgType]; exist { - if callReceivedMessage(ctx, rcvDta) { - select { - case wrk.consensusStateChangedChannel <- true: - default: + if receivedMessageCallbacks, exist := wrk.receivedMessagesCalls[msgType]; exist { + for _, callReceivedMessage := range receivedMessageCallbacks { + if callReceivedMessage(ctx, rcvDta) { + select { + case wrk.consensusStateChangedChannel <- true: + default: + } } } } + + wrk.callReceivedHeaderCallbacks(rcvDta) + } +} + +func (wrk *Worker) callReceivedHeaderCallbacks(message *consensus.Message) { + headerMessageType := wrk.consensusService.GetMessageTypeBlockHeader() + if message.MsgType != int64(headerMessageType) || !wrk.enableEpochsHandler.IsFlagEnabled(common.AndromedaFlag) { + return + } + + header := wrk.blockProcessor.DecodeBlockHeader(message.Header) + if check.IfNil(header) { + return + } + + wrk.mutReceivedHeadersHandler.RLock() + for _, handler := range wrk.receivedHeadersHandlers { + handler(header) + } + wrk.mutReceivedHeadersHandler.RUnlock() + + select { + case wrk.consensusStateChangedChannel <- true: + default: } } // Extend does an extension for the subround with subroundId func (wrk *Worker) Extend(subroundId int) { - wrk.consensusState.ExtendedCalled = true + wrk.consensusState.SetExtendedCalled(true) log.Debug("extend function is called", "subround", wrk.consensusService.GetSubroundName(subroundId)) @@ -681,9 +857,36 @@ func (wrk *Worker) Extend(subroundId int) { wrk.scheduledProcessor.ForceStopScheduledExecutionBlocking() wrk.blockProcessor.RevertCurrentBlock() + wrk.removeConsensusHeaderFromPool() + log.Debug("current block is reverted") } +func (wrk *Worker) removeConsensusHeaderFromPool() { + headerHash := wrk.consensusState.GetData() + if len(headerHash) == 0 { + return + } + + header := wrk.consensusState.GetHeader() + if check.IfNil(header) { + return + } + + if !wrk.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + return + } + + blockProcessorWithPoolAccess, ok := wrk.blockProcessor.(blockProcessorWithPool) + if !ok { + log.Error("removeConsensusHeaderFromPool: blockProcessorWithPoolAccess is nil") + return + } + + blockProcessorWithPoolAccess.RemoveHeaderFromPool(headerHash) + wrk.forkDetector.RemoveHeader(header.GetNonce(), headerHash) +} + // DisplayStatistics logs the consensus messages split on proposed headers func (wrk *Worker) DisplayStatistics() { wrk.mutDisplayHashConsensusMessage.Lock() @@ -732,11 +935,35 
@@ func (wrk *Worker) Close() error { return nil } -// ResetConsensusMessages resets at the start of each round all the previous consensus messages received +// ResetConsensusMessages resets at the start of each round all the previous consensus messages received and equivalent messages, keeping the provided proofs func (wrk *Worker) ResetConsensusMessages() { wrk.consensusMessageValidator.resetConsensusMessages() } +// ResetConsensusRoundState resets the consensus round state +func (wrk *Worker) ResetConsensusRoundState() { + wrk.consensusState.ResetConsensusRoundState() +} + +// ResetInvalidSignersCache resets the invalid signers cache +func (wrk *Worker) ResetInvalidSignersCache() { + wrk.invalidSignersCache.Reset() +} + +func (wrk *Worker) checkValidityAndProcessFinalInfo(cnsMsg *consensus.Message, p2pMessage p2p.MessageP2P) error { + msgType := consensus.MessageType(cnsMsg.MsgType) + + log.Trace("received message from consensus topic", + "msg type", wrk.consensusService.GetStringValue(msgType), + "from", cnsMsg.PubKey, + "header hash", cnsMsg.BlockHeaderHash, + "round", cnsMsg.RoundIndex, + "size", len(p2pMessage.Data()), + ) + + return wrk.consensusMessageValidator.checkConsensusMessageValidity(cnsMsg, p2pMessage.Peer()) +} + // IsInterfaceNil returns true if there is no value under the interface func (wrk *Worker) IsInterfaceNil() bool { return wrk == nil diff --git a/consensus/spos/worker_test.go b/consensus/spos/worker_test.go index b179fdf0db8..a144b88dcff 100644 --- a/consensus/spos/worker_test.go +++ b/consensus/spos/worker_test.go @@ -27,8 +27,13 @@ import ( "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" + "github.com/multiversx/mx-chain-go/testscommon/cache" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" + "github.com/multiversx/mx-chain-go/testscommon/processMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" ) @@ -58,14 +63,14 @@ func createDefaultWorkerArgs(appStatusHandler core.AppStatusHandler) *spos.Worke return nil }, } - bootstrapperMock := &mock.BootstrapperStub{} - broadcastMessengerMock := &mock.BroadcastMessengerMock{} + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} + broadcastMessengerMock := &consensusMocks.BroadcastMessengerMock{} consensusState := initConsensusState() - forkDetectorMock := &mock.ForkDetectorMock{} + forkDetectorMock := &processMocks.ForkDetectorStub{} forkDetectorMock.AddHeaderCalled = func(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error { return nil } - keyGeneratorMock, _, _ := mock.InitKeys() + keyGeneratorMock, _, _ := consensusMocks.InitKeys() marshalizerMock := mock.MarshalizerMock{} roundHandlerMock := initRoundHandlerMock() shardCoordinatorMock := mock.ShardCoordinatorMock{} @@ -77,10 +82,10 @@ func createDefaultWorkerArgs(appStatusHandler core.AppStatusHandler) *spos.Worke return nil }, } - syncTimerMock := &mock.SyncTimerMock{} + 
syncTimerMock := &consensusMocks.SyncTimerMock{} hasher := &hashingMocks.HasherMock{} blsService, _ := bls.NewConsensusService() - poolAdder := testscommon.NewCacherMock() + poolAdder := cache.NewCacherMock() scheduledProcessorArgs := spos.ScheduledProcessorWrapperArgs{ SyncTimer: syncTimerMock, @@ -90,6 +95,7 @@ func createDefaultWorkerArgs(appStatusHandler core.AppStatusHandler) *spos.Worke scheduledProcessor, _ := spos.NewScheduledProcessorWrapper(scheduledProcessorArgs) peerSigHandler := &mock.PeerSignatureHandler{Signer: singleSignerMock, KeyGen: keyGeneratorMock} + workerArgs := &spos.WorkerArgs{ ConsensusService: blsService, BlockChain: blockchainMock, @@ -105,8 +111,8 @@ func createDefaultWorkerArgs(appStatusHandler core.AppStatusHandler) *spos.Worke ShardCoordinator: shardCoordinatorMock, PeerSignatureHandler: peerSigHandler, SyncTimer: syncTimerMock, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, - HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + HeaderSigVerifier: &consensusMocks.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: &testscommon.HeaderVersionHandlerStub{}, ChainID: chainID, NetworkShardingCollector: &p2pmocks.NetworkShardingCollectorStub{}, AntifloodHandler: createMockP2PAntifloodHandler(), @@ -116,6 +122,8 @@ func createDefaultWorkerArgs(appStatusHandler core.AppStatusHandler) *spos.Worke AppStatusHandler: appStatusHandler, NodeRedundancyHandler: &mock.NodeRedundancyHandlerStub{}, PeerBlacklistHandler: &mock.PeerBlacklistHandlerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + InvalidSignersCache: &consensusMocks.InvalidSignersCacheMock{}, } return workerArgs @@ -136,11 +144,13 @@ func initWorker(appStatusHandler core.AppStatusHandler) *spos.Worker { workerArgs := createDefaultWorkerArgs(appStatusHandler) sposWorker, _ := spos.NewWorker(workerArgs) + sposWorker.ConsensusState().SetHeader(&block.HeaderV2{}) + return sposWorker } -func initRoundHandlerMock() *mock.RoundHandlerMock { - return &mock.RoundHandlerMock{ +func initRoundHandlerMock() *consensusMocks.RoundHandlerMock { + return &consensusMocks.RoundHandlerMock{ RoundIndex: 0, TimeStampCalled: func() time.Time { return time.Unix(0, 0) @@ -370,6 +380,28 @@ func TestWorker_NewWorkerNodeRedundancyHandlerShouldFail(t *testing.T) { assert.Equal(t, spos.ErrNilNodeRedundancyHandler, err) } +func TestWorker_NewWorkerPoolEnableEpochsHandlerNilShouldFail(t *testing.T) { + t.Parallel() + + workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) + workerArgs.EnableEpochsHandler = nil + wrk, err := spos.NewWorker(workerArgs) + + assert.Nil(t, wrk) + assert.Equal(t, spos.ErrNilEnableEpochsHandler, err) +} + +func TestWorker_NewWorkerPoolInvalidSignersCacheNilShouldFail(t *testing.T) { + t.Parallel() + + workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) + workerArgs.InvalidSignersCache = nil + wrk, err := spos.NewWorker(workerArgs) + + assert.Nil(t, wrk) + assert.Equal(t, spos.ErrNilInvalidSignersCache, err) +} + func TestWorker_NewWorkerShouldWork(t *testing.T) { t.Parallel() @@ -402,8 +434,9 @@ func TestWorker_ProcessReceivedMessageShouldErrIfFloodIsDetectedOnTopic(t *testi TopicField: "topic1", SignatureField: []byte("signature"), } - err := wrk.ProcessReceivedMessage(msg, "peer", &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(msg, "peer", &p2pmocks.MessengerStub{}) assert.Equal(t, expectedErr, err) + assert.Nil(t, msgID) } func 
TestWorker_ReceivedSyncStateShouldNotSendOnChannelWhenInputIsFalse(t *testing.T) { @@ -466,7 +499,7 @@ func TestWorker_AddReceivedMessageCallShouldWork(t *testing.T) { assert.Equal(t, 1, len(receivedMessageCalls)) assert.NotNil(t, receivedMessageCalls[bls.MtBlockBody]) - assert.True(t, receivedMessageCalls[bls.MtBlockBody](context.Background(), nil)) + assert.True(t, receivedMessageCalls[bls.MtBlockBody][0](context.Background(), nil)) } func TestWorker_RemoveAllReceivedMessageCallsShouldWork(t *testing.T) { @@ -480,7 +513,7 @@ func TestWorker_RemoveAllReceivedMessageCallsShouldWork(t *testing.T) { assert.Equal(t, 1, len(receivedMessageCalls)) assert.NotNil(t, receivedMessageCalls[bls.MtBlockBody]) - assert.True(t, receivedMessageCalls[bls.MtBlockBody](context.Background(), nil)) + assert.True(t, receivedMessageCalls[bls.MtBlockBody][0](context.Background(), nil)) wrk.RemoveAllReceivedMessagesCalls() receivedMessageCalls = wrk.ReceivedMessagesCalls() @@ -517,35 +550,37 @@ func TestWorker_ProcessReceivedMessageTxBlockBodyShouldRetNil(t *testing.T) { PeerField: currentPid, SignatureField: []byte("signature"), } - err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) - + msgID, err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Nil(t, err) + assert.Len(t, msgID, 0) } func TestWorker_ProcessReceivedMessageNilMessageShouldErr(t *testing.T) { t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) - err := wrk.ProcessReceivedMessage(nil, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(nil, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.Equal(t, spos.ErrNilMessage, err) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageNilMessageDataFieldShouldErr(t *testing.T) { t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) - err := wrk.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{}, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{}, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.Equal(t, spos.ErrNilDataToProcess, err) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageEmptySignatureFieldShouldErr(t *testing.T) { t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) - err := wrk.ProcessReceivedMessage( + msgID, err := wrk.ProcessReceivedMessage( &p2pmocks.P2PMessageMock{ DataField: []byte("data field"), }, @@ -556,6 +591,7 @@ func TestWorker_ProcessReceivedMessageEmptySignatureFieldShouldErr(t *testing.T) assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.Equal(t, spos.ErrNilSignatureOnP2PMessage, err) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageRedundancyNodeShouldResetInactivityIfNeeded(t *testing.T) { @@ -572,7 +608,7 @@ func TestWorker_ProcessReceivedMessageRedundancyNodeShouldResetInactivityIfNeede } wrk.SetNodeRedundancyHandler(nodeRedundancyMock) buff, _ := wrk.Marshalizer().Marshal(&consensus.Message{}) - _ = wrk.ProcessReceivedMessage( + _, _ = wrk.ProcessReceivedMessage( &p2pmocks.P2PMessageMock{ DataField: buff, SignatureField: []byte("signature"), @@ -606,7 +642,7 @@ func TestWorker_ProcessReceivedMessageNodeNotInEligibleListShouldErr(t *testing. 
nil, ) buff, _ := wrk.Marshalizer().Marshal(cnsMsg) - err := wrk.ProcessReceivedMessage( + msgID, err := wrk.ProcessReceivedMessage( &p2pmocks.P2PMessageMock{ DataField: buff, SignatureField: []byte("signature"), @@ -618,6 +654,7 @@ func TestWorker_ProcessReceivedMessageNodeNotInEligibleListShouldErr(t *testing. assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.True(t, errors.Is(err, spos.ErrNodeIsNotInEligibleList)) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageComputeReceivedProposedBlockMetric(t *testing.T) { @@ -765,7 +802,7 @@ func testWorkerProcessReceivedMessageComputeReceivedProposedBlockMetric( }, }) - wrk.SetRoundHandler(&mock.RoundHandlerMock{ + wrk.SetRoundHandler(&consensusMocks.RoundHandlerMock{ RoundIndex: 0, TimeDurationCalled: func() time.Duration { return roundDuration @@ -793,7 +830,7 @@ func testWorkerProcessReceivedMessageComputeReceivedProposedBlockMetric( nil, nil, hdrStr, - []byte(wrk.ConsensusState().ConsensusGroup()[0]), + []byte(wrk.ConsensusState().Leader()), signature, int(bls.MtBlockHeader), 0, @@ -813,7 +850,7 @@ func testWorkerProcessReceivedMessageComputeReceivedProposedBlockMetric( PeerField: currentPid, SignatureField: []byte("signature"), } - _ = wrk.ProcessReceivedMessage(msg, "", &p2pmocks.MessengerStub{}) + _, _ = wrk.ProcessReceivedMessage(msg, "", &p2pmocks.MessengerStub{}) return receivedValue, redundancyReason, redundancyStatus } @@ -841,7 +878,7 @@ func TestWorker_ProcessReceivedMessageInconsistentChainIDInConsensusMessageShoul nil, ) buff, _ := wrk.Marshalizer().Marshal(cnsMsg) - err := wrk.ProcessReceivedMessage( + msgID, err := wrk.ProcessReceivedMessage( &p2pmocks.P2PMessageMock{ DataField: buff, SignatureField: []byte("signature"), @@ -851,6 +888,7 @@ func TestWorker_ProcessReceivedMessageInconsistentChainIDInConsensusMessageShoul ) assert.True(t, errors.Is(err, spos.ErrInvalidChainID)) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageTypeInvalidShouldErr(t *testing.T) { @@ -875,7 +913,7 @@ func TestWorker_ProcessReceivedMessageTypeInvalidShouldErr(t *testing.T) { nil, ) buff, _ := wrk.Marshalizer().Marshal(cnsMsg) - err := wrk.ProcessReceivedMessage( + msgID, err := wrk.ProcessReceivedMessage( &p2pmocks.P2PMessageMock{ DataField: buff, SignatureField: []byte("signature"), @@ -887,6 +925,7 @@ func TestWorker_ProcessReceivedMessageTypeInvalidShouldErr(t *testing.T) { assert.Equal(t, 0, len(wrk.ReceivedMessages()[666])) assert.True(t, errors.Is(err, spos.ErrInvalidMessageType), err) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedHeaderHashSizeInvalidShouldErr(t *testing.T) { @@ -911,7 +950,7 @@ func TestWorker_ProcessReceivedHeaderHashSizeInvalidShouldErr(t *testing.T) { nil, ) buff, _ := wrk.Marshalizer().Marshal(cnsMsg) - err := wrk.ProcessReceivedMessage( + msgID, err := wrk.ProcessReceivedMessage( &p2pmocks.P2PMessageMock{ DataField: buff, SignatureField: []byte("signature"), @@ -923,6 +962,7 @@ func TestWorker_ProcessReceivedHeaderHashSizeInvalidShouldErr(t *testing.T) { assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.True(t, errors.Is(err, spos.ErrInvalidHeaderHashSize), err) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageForFutureRoundShouldErr(t *testing.T) { @@ -947,7 +987,7 @@ func TestWorker_ProcessReceivedMessageForFutureRoundShouldErr(t *testing.T) { nil, ) buff, _ := wrk.Marshalizer().Marshal(cnsMsg) - err := wrk.ProcessReceivedMessage( + msgID, err := wrk.ProcessReceivedMessage( &p2pmocks.P2PMessageMock{ DataField: buff, 
SignatureField: []byte("signature"), @@ -959,6 +999,7 @@ func TestWorker_ProcessReceivedMessageForFutureRoundShouldErr(t *testing.T) { assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.True(t, errors.Is(err, spos.ErrMessageForFutureRound)) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageForPastRoundShouldErr(t *testing.T) { @@ -983,7 +1024,7 @@ func TestWorker_ProcessReceivedMessageForPastRoundShouldErr(t *testing.T) { nil, ) buff, _ := wrk.Marshalizer().Marshal(cnsMsg) - err := wrk.ProcessReceivedMessage( + msgID, err := wrk.ProcessReceivedMessage( &p2pmocks.P2PMessageMock{ DataField: buff, SignatureField: []byte("signature"), @@ -995,6 +1036,7 @@ func TestWorker_ProcessReceivedMessageForPastRoundShouldErr(t *testing.T) { assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.True(t, errors.Is(err, spos.ErrMessageForPastRound)) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageTypeLimitReachedShouldErr(t *testing.T) { @@ -1025,20 +1067,23 @@ func TestWorker_ProcessReceivedMessageTypeLimitReachedShouldErr(t *testing.T) { SignatureField: []byte("signature"), } - err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, 1, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.Nil(t, err) + assert.Len(t, msgID, 0) - err = wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err = wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, 1, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.True(t, errors.Is(err, spos.ErrMessageTypeLimitReached)) + assert.Len(t, msgID, 0) - err = wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err = wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, 1, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.True(t, errors.Is(err, spos.ErrMessageTypeLimitReached)) + assert.Len(t, msgID, 0) } func TestWorker_ProcessReceivedMessageInvalidSignatureShouldErr(t *testing.T) { @@ -1063,7 +1108,7 @@ func TestWorker_ProcessReceivedMessageInvalidSignatureShouldErr(t *testing.T) { nil, ) buff, _ := wrk.Marshalizer().Marshal(cnsMsg) - err := wrk.ProcessReceivedMessage( + msgID, err := wrk.ProcessReceivedMessage( &p2pmocks.P2PMessageMock{ DataField: buff, SignatureField: []byte("signature"), @@ -1075,6 +1120,7 @@ func TestWorker_ProcessReceivedMessageInvalidSignatureShouldErr(t *testing.T) { assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.True(t, errors.Is(err, spos.ErrInvalidSignatureSize)) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageReceivedMessageIsFromSelfShouldRetNilAndNotProcess(t *testing.T) { @@ -1104,11 +1150,12 @@ func TestWorker_ProcessReceivedMessageReceivedMessageIsFromSelfShouldRetNilAndNo PeerField: currentPid, SignatureField: []byte("signature"), } - err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.Nil(t, err) + assert.Len(t, msgID, 0) } func TestWorker_ProcessReceivedMessageWhenRoundIsCanceledShouldRetNilAndNotProcess(t *testing.T) { @@ -1139,11 +1186,12 @@ 
func TestWorker_ProcessReceivedMessageWhenRoundIsCanceledShouldRetNilAndNotProce PeerField: currentPid, SignatureField: []byte("signature"), } - err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockBody])) assert.Nil(t, err) + assert.Len(t, msgID, 0) } func TestWorker_ProcessReceivedMessageWrongChainIDInProposedBlockShouldError(t *testing.T) { @@ -1185,7 +1233,7 @@ func TestWorker_ProcessReceivedMessageWrongChainIDInProposedBlockShouldError(t * nil, ) buff, _ := wrk.Marshalizer().Marshal(cnsMsg) - err := wrk.ProcessReceivedMessage( + msgID, err := wrk.ProcessReceivedMessage( &p2pmocks.P2PMessageMock{ DataField: buff, SignatureField: []byte("signature"), @@ -1196,6 +1244,7 @@ func TestWorker_ProcessReceivedMessageWrongChainIDInProposedBlockShouldError(t * time.Sleep(time.Second) assert.True(t, errors.Is(err, spos.ErrInvalidChainID)) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageWithABadOriginatorShouldErr(t *testing.T) { @@ -1246,11 +1295,12 @@ func TestWorker_ProcessReceivedMessageWithABadOriginatorShouldErr(t *testing.T) PeerField: "other originator", SignatureField: []byte("signature"), } - err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockHeader])) assert.True(t, errors.Is(err, spos.ErrOriginatorMismatch)) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageWithHeaderAndWrongHash(t *testing.T) { @@ -1258,6 +1308,7 @@ func TestWorker_ProcessReceivedMessageWithHeaderAndWrongHash(t *testing.T) { workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().SetHeader(&block.HeaderV2{}) wrk.SetBlockProcessor( &testscommon.BlockProcessorStub{ @@ -1304,11 +1355,12 @@ func TestWorker_ProcessReceivedMessageWithHeaderAndWrongHash(t *testing.T) { PeerField: currentPid, SignatureField: []byte("signature"), } - err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, 0, len(wrk.ReceivedMessages()[bls.MtBlockHeader])) assert.ErrorIs(t, err, spos.ErrWrongHashForHeader) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageOkValsShouldWork(t *testing.T) { @@ -1327,6 +1379,7 @@ func TestWorker_ProcessReceivedMessageOkValsShouldWork(t *testing.T) { }, } wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().SetHeader(&block.HeaderV2{}) wrk.SetBlockProcessor( &testscommon.BlockProcessorStub{ @@ -1373,12 +1426,13 @@ func TestWorker_ProcessReceivedMessageOkValsShouldWork(t *testing.T) { PeerField: currentPid, SignatureField: []byte("signature"), } - err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, 1, len(wrk.ReceivedMessages()[bls.MtBlockHeader])) assert.Nil(t, err) assert.True(t, wasUpdatePeerIDInfoCalled) + assert.Len(t, msgID, 0) } func TestWorker_CheckSelfStateShouldErrMessageFromItself(t *testing.T) { @@ -1671,7 +1725,7 @@ func 
TestWorker_CheckChannelsShouldWork(t *testing.T) { t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) wrk.StartWorking() - wrk.SetReceivedMessagesCalls(bls.MtBlockHeader, func(ctx context.Context, cnsMsg *consensus.Message) bool { + wrk.AppendReceivedMessagesCalls(bls.MtBlockHeader, func(ctx context.Context, cnsMsg *consensus.Message) bool { _ = wrk.ConsensusState().SetJobDone(wrk.ConsensusState().ConsensusGroup()[0], bls.SrBlock, true) return true }) @@ -1709,11 +1763,209 @@ func TestWorker_CheckChannelsShouldWork(t *testing.T) { _ = wrk.Close() } +func TestWorker_ConvertHeaderToConsensusMessage(t *testing.T) { + t.Parallel() + + t.Run("nil header should error", func(t *testing.T) { + wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) + _, err := wrk.ConvertHeaderToConsensusMessage(nil) + require.Equal(t, spos.ErrInvalidHeader, err) + }) + t.Run("valid header v2 should not error", func(t *testing.T) { + wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) + marshaller := wrk.Marshalizer() + hdr := &block.HeaderV2{ + Header: &block.Header{ + Round: 100, + }, + } + + hdrStr, _ := marshaller.Marshal(hdr) + expectedConsensusMsg := &consensus.Message{ + Header: hdrStr, + MsgType: int64(bls.MtBlockHeader), + RoundIndex: 100, + } + + message, err := wrk.ConvertHeaderToConsensusMessage(hdr) + require.Nil(t, err) + require.Equal(t, expectedConsensusMsg, message) + }) + t.Run("valid header metaHeader should not error", func(t *testing.T) { + wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) + marshaller := wrk.Marshalizer() + hdr := &block.MetaBlock{ + Round: 100, + } + + hdrStr, _ := marshaller.Marshal(hdr) + expectedConsensusMsg := &consensus.Message{ + Header: hdrStr, + MsgType: int64(bls.MtBlockHeader), + RoundIndex: 100, + } + + message, err := wrk.ConvertHeaderToConsensusMessage(hdr) + require.Nil(t, err) + require.Equal(t, expectedConsensusMsg, message) + }) +} + +func TestWorker_StoredHeadersExecution(t *testing.T) { + t.Parallel() + + hdr := &block.HeaderV2{ + Header: &block.Header{ + Round: 100, + }, + } + + t.Run("Test stored headers before current round advances to same round should not finalize round", func(t *testing.T) { + wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) + wrk.StartWorking() + wrk.AddReceivedHeaderHandler(func(handler data.HeaderHandler) { + _ = wrk.ConsensusState().SetJobDone(wrk.ConsensusState().ConsensusGroup()[0], bls.SrBlock, true) + }) + + roundIndex := &atomic.Int64{} + roundIndex.Store(99) + roundHandler := &consensusMocks.RoundHandlerMock{ + IndexCalled: func() int64 { + return roundIndex.Load() + }, + } + wrk.SetRoundHandler(roundHandler) + wrk.ConsensusState().SetRoundIndex(99) + cnsGroup := wrk.ConsensusState().ConsensusGroup() + + wrk.BlockProcessor().(*testscommon.BlockProcessorStub).DecodeBlockHeaderCalled = func(dta []byte) data.HeaderHandler { + return hdr + } + wrk.SetEnableEpochsHandler(&enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return true + }, + }) + + wrk.ConsensusState().SetStatus(bls.SrStartRound, spos.SsFinished) + wrk.AddFutureHeaderToProcessIfNeeded(hdr) + time.Sleep(200 * time.Millisecond) + wrk.ExecuteStoredMessages() + time.Sleep(200 * time.Millisecond) + + isBlockJobDone, err := wrk.ConsensusState().JobDone(cnsGroup[0], bls.SrBlock) + + assert.Nil(t, err) + assert.False(t, isBlockJobDone) + + _ = 
wrk.Close() + }) + t.Run("Test stored headers should finalize round after roundIndex advances", func(t *testing.T) { + wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) + wrk.StartWorking() + wrk.AddReceivedHeaderHandler(func(handler data.HeaderHandler) { + _ = wrk.ConsensusState().SetJobDone(wrk.ConsensusState().ConsensusGroup()[0], bls.SrBlock, true) + }) + + roundIndex := &atomic.Int64{} + roundIndex.Store(99) + roundHandler := &consensusMocks.RoundHandlerMock{ + IndexCalled: func() int64 { + return roundIndex.Load() + }, + } + wrk.SetRoundHandler(roundHandler) + + wrk.ConsensusState().SetRoundIndex(99) + cnsGroup := wrk.ConsensusState().ConsensusGroup() + + wrk.BlockProcessor().(*testscommon.BlockProcessorStub).DecodeBlockHeaderCalled = func(dta []byte) data.HeaderHandler { + return hdr + } + wrk.SetEnableEpochsHandler(&enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return true + }, + }) + + wrk.ConsensusState().SetStatus(bls.SrStartRound, spos.SsFinished) + wrk.AddFutureHeaderToProcessIfNeeded(hdr) + time.Sleep(200 * time.Millisecond) + roundIndex.Store(100) + wrk.ConsensusState().SetRoundIndex(100) + wrk.ExecuteStoredMessages() + time.Sleep(200 * time.Millisecond) + + isBlockJobDone, err := wrk.ConsensusState().JobDone(cnsGroup[0], bls.SrBlock) + + assert.Nil(t, err) + assert.True(t, isBlockJobDone) + + _ = wrk.Close() + }) + t.Run("Test stored meta headers should finalize round after roundIndex advances", func(t *testing.T) { + hdr := &block.MetaBlock{ + Round: 100, + } + + wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) + wrk.StartWorking() + wrk.AddReceivedHeaderHandler(func(handler data.HeaderHandler) { + _ = wrk.ConsensusState().SetJobDone(wrk.ConsensusState().ConsensusGroup()[0], bls.SrBlock, true) + }) + + roundIndex := &atomic.Int64{} + roundIndex.Store(99) + roundHandler := &consensusMocks.RoundHandlerMock{ + IndexCalled: func() int64 { + return roundIndex.Load() + }, + } + wrk.SetRoundHandler(roundHandler) + + wrk.ConsensusState().SetRoundIndex(99) + cnsGroup := wrk.ConsensusState().ConsensusGroup() + + wrk.BlockProcessor().(*testscommon.BlockProcessorStub).DecodeBlockHeaderCalled = func(dta []byte) data.HeaderHandler { + return hdr + } + wrk.SetEnableEpochsHandler(&enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return true + }, + }) + + wrk.ConsensusState().SetStatus(bls.SrStartRound, spos.SsFinished) + wrk.AddFutureHeaderToProcessIfNeeded(hdr) + time.Sleep(200 * time.Millisecond) + roundIndex.Store(100) + wrk.ConsensusState().SetRoundIndex(100) + wrk.ExecuteStoredMessages() + time.Sleep(200 * time.Millisecond) + + isBlockJobDone, err := wrk.ConsensusState().JobDone(cnsGroup[0], bls.SrBlock) + + assert.Nil(t, err) + assert.True(t, isBlockJobDone) + + _ = wrk.Close() + }) +} + func TestWorker_ExtendShouldReturnWhenRoundIsCanceled(t *testing.T) { t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) executed := false - bootstrapperMock := &mock.BootstrapperStub{ + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{ GetNodeStateCalled: func() common.NodeState { return common.NsNotSynchronized }, @@ -1733,7 +1985,7 @@ func TestWorker_ExtendShouldReturnWhenGetNodeStateNotReturnSynchronized(t 
*testi t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) executed := false - bootstrapperMock := &mock.BootstrapperStub{ + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{ GetNodeStateCalled: func() common.NodeState { return common.NsNotSynchronized }, @@ -1752,14 +2004,14 @@ func TestWorker_ExtendShouldReturnWhenCreateEmptyBlockFail(t *testing.T) { t.Parallel() wrk := *initWorker(&statusHandlerMock.AppStatusHandlerStub{}) executed := false - bmm := &mock.BroadcastMessengerMock{ + bmm := &consensusMocks.BroadcastMessengerMock{ BroadcastBlockCalled: func(handler data.BodyHandler, handler2 data.HeaderHandler) error { executed = true return nil }, } wrk.SetBroadcastMessenger(bmm) - bootstrapperMock := &mock.BootstrapperStub{ + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{ CreateAndCommitEmptyBlockCalled: func(shardForCurrentNode uint32) (data.BodyHandler, data.HeaderHandler, error) { return nil, nil, errors.New("error") }} @@ -1863,13 +2115,14 @@ func TestWorker_ProcessReceivedMessageWrongHeaderShouldErr(t *testing.T) { t.Parallel() workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) - headerSigVerifier := &mock.HeaderSigVerifierStub{} + headerSigVerifier := &consensusMocks.HeaderSigVerifierMock{} headerSigVerifier.VerifyRandSeedCalled = func(header data.HeaderHandler) error { return process.ErrRandSeedDoesNotMatch } workerArgs.HeaderSigVerifier = headerSigVerifier wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().SetHeader(&block.HeaderV2{}) hdr := &block.Header{} hdr.Nonce = 1 @@ -1899,8 +2152,9 @@ func TestWorker_ProcessReceivedMessageWrongHeaderShouldErr(t *testing.T) { PeerField: currentPid, SignatureField: []byte("signature"), } - err := wrk.ProcessReceivedMessage(msg, "", &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(msg, "", &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, spos.ErrInvalidHeader)) + assert.Nil(t, msgID) } func TestWorker_ProcessReceivedMessageWithSignature(t *testing.T) { @@ -1911,6 +2165,7 @@ func TestWorker_ProcessReceivedMessageWithSignature(t *testing.T) { workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().SetHeader(&block.HeaderV2{}) hdr := &block.Header{} hdr.Nonce = 1 @@ -1944,11 +2199,198 @@ func TestWorker_ProcessReceivedMessageWithSignature(t *testing.T) { PeerField: currentPid, SignatureField: []byte("signature"), } - err = wrk.ProcessReceivedMessage(msg, "", &p2pmocks.MessengerStub{}) + msgID, err := wrk.ProcessReceivedMessage(msg, "", &p2pmocks.MessengerStub{}) assert.Nil(t, err) + assert.Len(t, msgID, 0) p2pMsgWithSignature, ok := wrk.ConsensusState().GetMessageWithSignature(string(pubKey)) require.True(t, ok) require.Equal(t, msg, p2pMsgWithSignature) }) } + +func TestWorker_ProcessReceivedMessageWithInvalidSigners(t *testing.T) { + t.Parallel() + + workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) + cntCheckKnownInvalidSignersCalled := 0 + workerArgs.InvalidSignersCache = &consensusMocks.InvalidSignersCacheMock{ + CheckKnownInvalidSignersCalled: func(headerHash []byte, invalidSigners []byte) bool { + cntCheckKnownInvalidSignersCalled++ + return cntCheckKnownInvalidSignersCalled > 1 + }, + } + workerArgs.AntifloodHandler = &mock.P2PAntifloodHandlerStub{ + CanProcessMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { + return nil + }, + CanProcessMessagesOnTopicCalled: func(peer core.PeerID, 
topic string, numMessages uint32, totalSize uint64, sequence []byte) error { + return nil + }, + BlacklistPeerCalled: func(peer core.PeerID, reason string, duration time.Duration) { + require.Fail(t, "should have not been called") + }, + } + workerArgs.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return true + }, + } + wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().SetHeader(&block.HeaderV2{}) + + hdr := &block.Header{} + hdr.Nonce = 1 + hdr.TimeStamp = uint64(wrk.RoundHandler().TimeStamp().Unix()) + hdrStr, _ := mock.MarshalizerMock{}.Marshal(hdr) + hdrHash := (&hashingMocks.HasherMock{}).Compute(string(hdrStr)) + pubKey := []byte(wrk.ConsensusState().ConsensusGroup()[0]) + + invalidSigners := []byte("invalid signers") + cnsMsg := consensus.NewConsensusMessage( + hdrHash, + nil, + nil, + nil, + pubKey, + bytes.Repeat([]byte("a"), SignatureSize), + int(bls.MtInvalidSigners), + 0, + chainID, + nil, + nil, + nil, + currentPid, + invalidSigners, + ) + buff, err := wrk.Marshalizer().Marshal(cnsMsg) + require.Nil(t, err) + + msg := &p2pmocks.P2PMessageMock{ + DataField: buff, + PeerField: currentPid, + SignatureField: []byte("signature"), + } + + // first call should be ok + msgID, err := wrk.ProcessReceivedMessage(msg, "", &p2pmocks.MessengerStub{}) + require.Nil(t, err) + require.Len(t, msgID, 0) + + // reset the received messages to allow a second one of the same type + wrk.ResetConsensusMessages() + + // second call should see this message as already received and return error + msgID, err = wrk.ProcessReceivedMessage(msg, "", &p2pmocks.MessengerStub{}) + require.Equal(t, spos.ErrInvalidSignersAlreadyReceived, err) + require.Nil(t, msgID) +} + +func TestWorker_ReceivedHeader(t *testing.T) { + t.Parallel() + + t.Run("nil header should early exit", func(t *testing.T) { + t.Parallel() + + workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) + wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().SetHeader(&block.HeaderV2{}) + + rcvHeaderHandler := func(header data.HeaderHandler) { + require.Fail(t, "should have not been called") + } + wrk.AddReceivedHeaderHandler(rcvHeaderHandler) + wrk.ReceivedHeader(nil, nil) + }) + t.Run("unprocessable header should early exit", func(t *testing.T) { + t.Parallel() + + workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) + wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().SetHeader(&block.HeaderV2{}) + + rcvHeaderHandler := func(header data.HeaderHandler) { + require.Fail(t, "should have not been called") + } + wrk.AddReceivedHeaderHandler(rcvHeaderHandler) + wrk.ReceivedHeader(&block.Header{ + ShardID: workerArgs.ShardCoordinator.SelfId(), + Round: uint64(workerArgs.RoundHandler.Index() + 1), // should not process this one + }, nil) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + wasSetUInt64ValueCalled := false + setStringValueCnt := 0 + appStatusHandler := &statusHandlerMock.AppStatusHandlerStub{ + SetUInt64ValueHandler: func(key string, value uint64) { + require.Equal(t, common.MetricReceivedProposedBlock, key) + wasSetUInt64ValueCalled = true + }, + SetStringValueHandler: func(key string, value string) { + setStringValueCnt++ + if key != common.MetricRedundancyIsMainActive && + key != common.MetricRedundancyStepInReason { + require.Fail(t, "unexpected key for SetStringValue") + } + }, + } + workerArgs := createDefaultWorkerArgs(appStatusHandler) + 
workerArgs.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().SetHeader(&block.HeaderV2{}) + + wasHandlerCalled := false + rcvHeaderHandler := func(header data.HeaderHandler) { + wasHandlerCalled = true + } + wrk.AddReceivedHeaderHandler(rcvHeaderHandler) + wrk.ReceivedHeader(&block.Header{ + ShardID: workerArgs.ShardCoordinator.SelfId(), + Round: uint64(workerArgs.RoundHandler.Index()), + }, nil) + require.True(t, wasHandlerCalled) + + wrk.RemoveAllReceivedHeaderHandlers() // coverage only + require.True(t, wasSetUInt64ValueCalled) + require.Equal(t, 2, setStringValueCnt) + }) +} + +func TestWorker_ReceivedProof(t *testing.T) { + t.Parallel() + + t.Run("nil proof should early exit", func(t *testing.T) { + t.Parallel() + + workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) + wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().SetHeader(&block.HeaderV2{}) + + rcvProofHandler := func(proof consensus.ProofHandler) { + require.Fail(t, "should have not been called") + } + wrk.AddReceivedProofHandler(rcvProofHandler) + wrk.ReceivedProof(nil) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + workerArgs := createDefaultWorkerArgs(&statusHandlerMock.AppStatusHandlerStub{}) + wrk, _ := spos.NewWorker(workerArgs) + wrk.ConsensusState().SetHeader(&block.HeaderV2{}) + + wasHandlerCalled := false + rcvProofHandler := func(proof consensus.ProofHandler) { + wasHandlerCalled = true + } + wrk.AddReceivedProofHandler(rcvProofHandler) + wrk.ReceivedProof(&block.HeaderProof{}) + require.True(t, wasHandlerCalled) + }) +} diff --git a/dataRetriever/blockchain/baseBlockchain_test.go b/dataRetriever/blockchain/baseBlockchain_test.go index 3f6121b6a07..69a49304db0 100644 --- a/dataRetriever/blockchain/baseBlockchain_test.go +++ b/dataRetriever/blockchain/baseBlockchain_test.go @@ -8,6 +8,8 @@ import ( ) func TestBaseBlockchain_SetAndGetSetFinalBlockInfo(t *testing.T) { + t.Parallel() + base := &baseBlockChain{ appStatusHandler: &mock.AppStatusHandlerStub{}, finalBlockInfo: &blockInfo{}, @@ -26,6 +28,8 @@ func TestBaseBlockchain_SetAndGetSetFinalBlockInfo(t *testing.T) { } func TestBaseBlockchain_SetAndGetSetFinalBlockInfoWorksWithNilValues(t *testing.T) { + t.Parallel() + base := &baseBlockChain{ appStatusHandler: &mock.AppStatusHandlerStub{}, finalBlockInfo: &blockInfo{}, diff --git a/dataRetriever/dataPool/dataPool.go b/dataRetriever/dataPool/dataPool.go index 67b55cbfaee..be759b15b43 100644 --- a/dataRetriever/dataPool/dataPool.go +++ b/dataRetriever/dataPool/dataPool.go @@ -26,6 +26,7 @@ type dataPool struct { peerAuthentications storage.Cacher heartbeats storage.Cacher validatorsInfo dataRetriever.ShardedDataCacherNotifier + proofs dataRetriever.ProofsPool } // DataPoolArgs represents the data pool's constructor structure @@ -44,6 +45,7 @@ type DataPoolArgs struct { PeerAuthentications storage.Cacher Heartbeats storage.Cacher ValidatorsInfo dataRetriever.ShardedDataCacherNotifier + Proofs dataRetriever.ProofsPool } // NewDataPool creates a data pools holder object @@ -90,6 +92,9 @@ func NewDataPool(args DataPoolArgs) (*dataPool, error) { if check.IfNil(args.ValidatorsInfo) { return nil, dataRetriever.ErrNilValidatorInfoPool } + if check.IfNil(args.Proofs) { + return nil, dataRetriever.ErrNilProofsPool + } return &dataPool{ transactions: 
args.Transactions, @@ -106,6 +111,7 @@ func NewDataPool(args DataPoolArgs) (*dataPool, error) { peerAuthentications: args.PeerAuthentications, heartbeats: args.Heartbeats, validatorsInfo: args.ValidatorsInfo, + proofs: args.Proofs, }, nil } @@ -179,6 +185,11 @@ func (dp *dataPool) ValidatorsInfo() dataRetriever.ShardedDataCacherNotifier { return dp.validatorsInfo } +// Proofs returns the holder for equivalent proofs +func (dp *dataPool) Proofs() dataRetriever.ProofsPool { + return dp.proofs +} + // Close closes all the components func (dp *dataPool) Close() error { var lastError error diff --git a/dataRetriever/dataPool/dataPool_test.go b/dataRetriever/dataPool/dataPool_test.go index b948b7f2d44..9a8f17181e3 100644 --- a/dataRetriever/dataPool/dataPool_test.go +++ b/dataRetriever/dataPool/dataPool_test.go @@ -8,11 +8,14 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever/dataPool" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -//------- NewDataPool +// ------- NewDataPool func createMockDataPoolArgs() dataPool.DataPoolArgs { return dataPool.DataPoolArgs{ @@ -20,16 +23,17 @@ func createMockDataPoolArgs() dataPool.DataPoolArgs { UnsignedTransactions: testscommon.NewShardedDataStub(), RewardTransactions: testscommon.NewShardedDataStub(), Headers: &mock.HeadersCacherStub{}, - MiniBlocks: testscommon.NewCacherStub(), - PeerChangesBlocks: testscommon.NewCacherStub(), - TrieNodes: testscommon.NewCacherStub(), - TrieNodesChunks: testscommon.NewCacherStub(), + MiniBlocks: cache.NewCacherStub(), + PeerChangesBlocks: cache.NewCacherStub(), + TrieNodes: cache.NewCacherStub(), + TrieNodesChunks: cache.NewCacherStub(), CurrentBlockTransactions: &mock.TxForCurrentBlockStub{}, CurrentEpochValidatorInfo: &mock.ValidatorInfoForCurrentEpochStub{}, - SmartContracts: testscommon.NewCacherStub(), - PeerAuthentications: testscommon.NewCacherStub(), - Heartbeats: testscommon.NewCacherStub(), + SmartContracts: cache.NewCacherStub(), + PeerAuthentications: cache.NewCacherStub(), + Heartbeats: cache.NewCacherStub(), ValidatorsInfo: testscommon.NewShardedDataStub(), + Proofs: &dataRetrieverMocks.ProofsPoolMock{}, } } @@ -195,7 +199,7 @@ func TestNewDataPool_OkValsShouldWork(t *testing.T) { assert.Nil(t, err) require.False(t, tdp.IsInterfaceNil()) - //pointer checking + // pointer checking assert.True(t, args.Transactions == tdp.Transactions()) assert.True(t, args.UnsignedTransactions == tdp.UnsignedTransactions()) assert.True(t, args.RewardTransactions == tdp.RewardTransactions()) @@ -220,7 +224,7 @@ func TestNewDataPool_Close(t *testing.T) { t.Parallel() args := createMockDataPoolArgs() - args.TrieNodes = &testscommon.CacherStub{ + args.TrieNodes = &cache.CacherStub{ CloseCalled: func() error { return expectedErr }, @@ -234,7 +238,7 @@ func TestNewDataPool_Close(t *testing.T) { t.Parallel() args := createMockDataPoolArgs() - args.PeerAuthentications = &testscommon.CacherStub{ + args.PeerAuthentications = &cache.CacherStub{ CloseCalled: func() error { return expectedErr }, @@ -251,13 +255,13 @@ func TestNewDataPool_Close(t *testing.T) { paExpectedErr := errors.New("pa expected error") args := createMockDataPoolArgs() 
tnCalled, paCalled := false, false - args.TrieNodes = &testscommon.CacherStub{ + args.TrieNodes = &cache.CacherStub{ CloseCalled: func() error { tnCalled = true return tnExpectedErr }, } - args.PeerAuthentications = &testscommon.CacherStub{ + args.PeerAuthentications = &cache.CacherStub{ CloseCalled: func() error { paCalled = true return paExpectedErr @@ -275,13 +279,13 @@ func TestNewDataPool_Close(t *testing.T) { args := createMockDataPoolArgs() tnCalled, paCalled := false, false - args.TrieNodes = &testscommon.CacherStub{ + args.TrieNodes = &cache.CacherStub{ CloseCalled: func() error { tnCalled = true return nil }, } - args.PeerAuthentications = &testscommon.CacherStub{ + args.PeerAuthentications = &cache.CacherStub{ CloseCalled: func() error { paCalled = true return nil diff --git a/dataRetriever/dataPool/headersCache/headersPool.go b/dataRetriever/dataPool/headersCache/headersPool.go index cf824cc6e10..09a68e10c2b 100644 --- a/dataRetriever/dataPool/headersCache/headersPool.go +++ b/dataRetriever/dataPool/headersCache/headersPool.go @@ -5,9 +5,10 @@ import ( "sync" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" - "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("dataRetriever/headersCache") @@ -64,6 +65,7 @@ func (pool *headersPool) AddHeader(headerHash []byte, header data.HeaderHandler) added := pool.cache.addHeader(headerHash, header) if added { + log.Trace("added header to pool", "header shard", header.GetShardID(), "header nonce", header.GetNonce(), "header hash", headerHash) pool.callAddedDataHandlers(header, headerHash) } } diff --git a/dataRetriever/dataPool/proofsCache/errors.go b/dataRetriever/dataPool/proofsCache/errors.go new file mode 100644 index 00000000000..06cd0604274 --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/errors.go @@ -0,0 +1,6 @@ +package proofscache + +import "errors" + +// ErrMissingProof signals that the proof is missing +var ErrMissingProof = errors.New("missing proof") diff --git a/dataRetriever/dataPool/proofsCache/export_test.go b/dataRetriever/dataPool/proofsCache/export_test.go new file mode 100644 index 00000000000..f6b0b007405 --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/export_test.go @@ -0,0 +1,34 @@ +package proofscache + +import "github.com/multiversx/mx-chain-core-go/data" + +// NewProofsCache - +func NewProofsCache(bucketSize int) *proofsCache { + return newProofsCache(bucketSize) +} + +// HeadBucketSize - +func (pc *proofsCache) FullProofsByNonceSize() int { + size := 0 + + for _, bucket := range pc.proofsByNonceBuckets { + size += bucket.size() + } + + return size +} + +// ProofsByHashSize - +func (pc *proofsCache) ProofsByHashSize() int { + return len(pc.proofsByHash) +} + +// AddProof - +func (pc *proofsCache) AddProof(proof data.HeaderProofHandler) { + pc.addProof(proof) +} + +// CleanupProofsBehindNonce - +func (pc *proofsCache) CleanupProofsBehindNonce(nonce uint64) { + pc.cleanupProofsBehindNonce(nonce) +} diff --git a/dataRetriever/dataPool/proofsCache/proofsBucket.go b/dataRetriever/dataPool/proofsCache/proofsBucket.go new file mode 100644 index 00000000000..91b5815f440 --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/proofsBucket.go @@ -0,0 +1,26 @@ +package proofscache + +import "github.com/multiversx/mx-chain-core-go/data" + 
+type proofNonceBucket struct { + maxNonce uint64 + proofsByNonce map[uint64]string +} + +func newProofBucket() *proofNonceBucket { + return &proofNonceBucket{ + proofsByNonce: make(map[uint64]string), + } +} + +func (p *proofNonceBucket) size() int { + return len(p.proofsByNonce) +} + +func (p *proofNonceBucket) insert(proof data.HeaderProofHandler) { + p.proofsByNonce[proof.GetHeaderNonce()] = string(proof.GetHeaderHash()) + + if proof.GetHeaderNonce() > p.maxNonce { + p.maxNonce = proof.GetHeaderNonce() + } +} diff --git a/dataRetriever/dataPool/proofsCache/proofsCache.go b/dataRetriever/dataPool/proofsCache/proofsCache.go new file mode 100644 index 00000000000..d885ffe8a41 --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/proofsCache.go @@ -0,0 +1,110 @@ +package proofscache + +import ( + "sync" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" +) + +type proofsCache struct { + mutProofsCache sync.RWMutex + proofsByNonceBuckets map[uint64]*proofNonceBucket + bucketSize uint64 + proofsByHash map[string]data.HeaderProofHandler +} + +func newProofsCache(bucketSize int) *proofsCache { + return &proofsCache{ + proofsByNonceBuckets: make(map[uint64]*proofNonceBucket), + bucketSize: uint64(bucketSize), + proofsByHash: make(map[string]data.HeaderProofHandler), + } +} + +func (pc *proofsCache) getProofByHash(headerHash []byte) (data.HeaderProofHandler, error) { + pc.mutProofsCache.RLock() + defer pc.mutProofsCache.RUnlock() + + proof, ok := pc.proofsByHash[string(headerHash)] + if !ok { + return nil, ErrMissingProof + } + + return proof, nil +} + +func (pc *proofsCache) getProofByNonce(headerNonce uint64) (data.HeaderProofHandler, error) { + pc.mutProofsCache.RLock() + defer pc.mutProofsCache.RUnlock() + + bucketKey := pc.getBucketKey(headerNonce) + bucket, ok := pc.proofsByNonceBuckets[bucketKey] + if !ok { + return nil, ErrMissingProof + } + + proofHash, ok := bucket.proofsByNonce[headerNonce] + if !ok { + return nil, ErrMissingProof + } + + proof, ok := pc.proofsByHash[proofHash] + if !ok { + return nil, ErrMissingProof + } + + return proof, nil +} + +func (pc *proofsCache) addProof(proof data.HeaderProofHandler) { + if check.IfNil(proof) { + return + } + + pc.mutProofsCache.Lock() + defer pc.mutProofsCache.Unlock() + + pc.insertProofByNonce(proof) + + pc.proofsByHash[string(proof.GetHeaderHash())] = proof +} + +// getBucketKey will return bucket key as lower bound window value +func (pc *proofsCache) getBucketKey(index uint64) uint64 { + return (index / pc.bucketSize) * pc.bucketSize +} + +func (pc *proofsCache) insertProofByNonce(proof data.HeaderProofHandler) { + bucketKey := pc.getBucketKey(proof.GetHeaderNonce()) + + bucket, ok := pc.proofsByNonceBuckets[bucketKey] + if !ok { + bucket = newProofBucket() + pc.proofsByNonceBuckets[bucketKey] = bucket + } + + bucket.insert(proof) +} + +func (pc *proofsCache) cleanupProofsBehindNonce(nonce uint64) { + if nonce == 0 { + return + } + + pc.mutProofsCache.Lock() + defer pc.mutProofsCache.Unlock() + + for key, bucket := range pc.proofsByNonceBuckets { + if nonce > bucket.maxNonce { + pc.cleanupProofsInBucket(bucket) + delete(pc.proofsByNonceBuckets, key) + } + } +} + +func (pc *proofsCache) cleanupProofsInBucket(bucket *proofNonceBucket) { + for _, headerHash := range bucket.proofsByNonce { + delete(pc.proofsByHash, headerHash) + } +} diff --git a/dataRetriever/dataPool/proofsCache/proofsCache_bench_test.go 
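Note on the bucketing scheme introduced by proofsCache above: every proof is indexed twice, once by header hash (proofsByHash) and once by nonce, where nonces are grouped into windows of bucketSize keyed by the window's lower bound (getBucketKey). cleanupProofsBehindNonce then only compares the target nonce against each bucket's maxNonce and drops whole buckets, together with their proofsByHash entries, instead of scanning every stored proof. The following is a minimal, self-contained sketch of that idea, not the repository code: it stores plain hash strings instead of data.HeaderProofHandler values and omits the RWMutex.

```go
package main

import "fmt"

// bucket mirrors proofNonceBucket: it remembers the highest nonce it holds
// so that cleanup can decide whether the whole bucket is already stale.
type bucket struct {
	maxNonce      uint64
	hashesByNonce map[uint64]string
}

// sketchCache is a simplified stand-in for proofsCache.
type sketchCache struct {
	bucketSize   uint64
	buckets      map[uint64]*bucket // keyed by the lower bound of the nonce window
	proofsByHash map[string]struct{}
}

func newSketchCache(bucketSize uint64) *sketchCache {
	return &sketchCache{
		bucketSize:   bucketSize,
		buckets:      make(map[uint64]*bucket),
		proofsByHash: make(map[string]struct{}),
	}
}

// bucketKey mirrors getBucketKey: the lower bound of the nonce window.
func (sc *sketchCache) bucketKey(nonce uint64) uint64 {
	return (nonce / sc.bucketSize) * sc.bucketSize
}

func (sc *sketchCache) add(nonce uint64, headerHash string) {
	key := sc.bucketKey(nonce)
	b, ok := sc.buckets[key]
	if !ok {
		b = &bucket{hashesByNonce: make(map[uint64]string)}
		sc.buckets[key] = b
	}
	b.hashesByNonce[nonce] = headerHash
	if nonce > b.maxNonce {
		b.maxNonce = nonce
	}
	sc.proofsByHash[headerHash] = struct{}{}
}

// cleanupBehind removes every bucket whose highest stored nonce is below
// the given nonce, dropping the hash entries that bucket owned.
func (sc *sketchCache) cleanupBehind(nonce uint64) {
	for key, b := range sc.buckets {
		if nonce > b.maxNonce {
			for _, hash := range b.hashesByNonce {
				delete(sc.proofsByHash, hash)
			}
			delete(sc.buckets, key)
		}
	}
}

func main() {
	sc := newSketchCache(100)
	sc.add(5, "hash_5")
	sc.add(150, "hash_150")
	sc.cleanupBehind(120)             // drops only the [0,100) bucket
	fmt.Println(len(sc.proofsByHash)) // 1
}
```

The benchmarks added right below exercise exactly this trade-off: a larger bucketSize means fewer buckets to manage on insert, at the cost of coarser-grained cleanup.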
b/dataRetriever/dataPool/proofsCache/proofsCache_bench_test.go new file mode 100644 index 00000000000..910895b099c --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/proofsCache_bench_test.go @@ -0,0 +1,66 @@ +package proofscache_test + +import ( + "fmt" + "testing" + + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" +) + +func Benchmark_AddProof_Bucket10_Pool1000(b *testing.B) { + benchmarkAddProof(b, 10, 1000) +} + +func Benchmark_AddProof_Bucket100_Pool10000(b *testing.B) { + benchmarkAddProof(b, 100, 10000) +} + +func Benchmark_AddProof_Bucket1000_Pool100000(b *testing.B) { + benchmarkAddProof(b, 1000, 100000) +} + +func benchmarkAddProof(b *testing.B, bucketSize int, nonceRange int) { + pc := proofscache.NewProofsCache(bucketSize) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + proof := generateProof() + nonce := generateRandomNonce(int64(nonceRange)) + + proof.HeaderNonce = nonce + proof.HeaderHash = []byte("hash_" + fmt.Sprintf("%d", nonce)) + b.StartTimer() + + pc.AddProof(proof) + } +} + +func Benchmark_CleanupProofs_Bucket10_Pool1000(b *testing.B) { + benchmarkCleanupProofs(b, 10, 1000) +} + +func Benchmark_CleanupProofs_Bucket100_Pool10000(b *testing.B) { + benchmarkCleanupProofs(b, 100, 10000) +} + +func Benchmark_CleanupProofs_Bucket1000_Pool100000(b *testing.B) { + benchmarkCleanupProofs(b, 1000, 100000) +} + +func benchmarkCleanupProofs(b *testing.B, bucketSize int, nonceRange int) { + pc := proofscache.NewProofsCache(bucketSize) + + for i := uint64(0); i < uint64(nonceRange); i++ { + proof := generateProof() + proof.HeaderNonce = i + proof.HeaderHash = []byte("hash_" + fmt.Sprintf("%d", i)) + + pc.AddProof(proof) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pc.CleanupProofsBehindNonce(uint64(nonceRange)) + } +} diff --git a/dataRetriever/dataPool/proofsCache/proofsCache_test.go b/dataRetriever/dataPool/proofsCache/proofsCache_test.go new file mode 100644 index 00000000000..84bc70d9104 --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/proofsCache_test.go @@ -0,0 +1,152 @@ +package proofscache_test + +import ( + "fmt" + "math/rand" + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/data/block" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProofsCache(t *testing.T) { + t.Parallel() + + t.Run("incremental nonces, should cleanup all caches", func(t *testing.T) { + t.Parallel() + + proof0 := &block.HeaderProof{HeaderHash: []byte{0}, HeaderNonce: 0} + proof1 := &block.HeaderProof{HeaderHash: []byte{1}, HeaderNonce: 1} + proof2 := &block.HeaderProof{HeaderHash: []byte{2}, HeaderNonce: 2} + proof3 := &block.HeaderProof{HeaderHash: []byte{3}, HeaderNonce: 3} + proof4 := &block.HeaderProof{HeaderHash: []byte{4}, HeaderNonce: 4} + + pc := proofscache.NewProofsCache(4) + + pc.AddProof(proof0) + pc.AddProof(proof1) + pc.AddProof(proof2) + pc.AddProof(proof3) + + require.Equal(t, 4, pc.FullProofsByNonceSize()) + require.Equal(t, 4, pc.ProofsByHashSize()) + + pc.AddProof(proof4) // added to new head bucket + + require.Equal(t, 5, pc.ProofsByHashSize()) + + pc.CleanupProofsBehindNonce(4) + require.Equal(t, 1, pc.ProofsByHashSize()) + + pc.CleanupProofsBehindNonce(10) + require.Equal(t, 0, pc.ProofsByHashSize()) + }) + + t.Run("non incremental nonces", func(t *testing.T) { + t.Parallel() + + 
proof0 := &block.HeaderProof{HeaderHash: []byte{0}, HeaderNonce: 0} + proof1 := &block.HeaderProof{HeaderHash: []byte{1}, HeaderNonce: 1} + proof2 := &block.HeaderProof{HeaderHash: []byte{2}, HeaderNonce: 2} + proof3 := &block.HeaderProof{HeaderHash: []byte{3}, HeaderNonce: 3} + proof4 := &block.HeaderProof{HeaderHash: []byte{4}, HeaderNonce: 4} + proof5 := &block.HeaderProof{HeaderHash: []byte{5}, HeaderNonce: 5} + + pc := proofscache.NewProofsCache(4) + + pc.AddProof(proof4) + pc.AddProof(proof1) + pc.AddProof(proof2) + pc.AddProof(proof3) + + require.Equal(t, 4, pc.FullProofsByNonceSize()) + require.Equal(t, 4, pc.ProofsByHashSize()) + + pc.AddProof(proof0) // added to new head bucket + + require.Equal(t, 5, pc.FullProofsByNonceSize()) + require.Equal(t, 5, pc.ProofsByHashSize()) + + pc.CleanupProofsBehindNonce(4) + + // cleanup up head bucket with only one proof + require.Equal(t, 1, pc.ProofsByHashSize()) + + pc.AddProof(proof5) // added to new head bucket + + require.Equal(t, 2, pc.ProofsByHashSize()) + + pc.CleanupProofsBehindNonce(5) // will not remove any bucket + require.Equal(t, 2, pc.FullProofsByNonceSize()) + require.Equal(t, 2, pc.ProofsByHashSize()) + + pc.CleanupProofsBehindNonce(10) + require.Equal(t, 0, pc.ProofsByHashSize()) + }) + + t.Run("shuffled nonces, should cleanup all caches", func(t *testing.T) { + t.Parallel() + + pc := proofscache.NewProofsCache(10) + + nonces := generateShuffledNonces(100) + for _, nonce := range nonces { + proof := generateProof() + proof.HeaderNonce = nonce + proof.HeaderHash = []byte("hash_" + fmt.Sprintf("%d", nonce)) + + pc.AddProof(proof) + } + + require.Equal(t, 100, pc.FullProofsByNonceSize()) + require.Equal(t, 100, pc.ProofsByHashSize()) + + pc.CleanupProofsBehindNonce(100) + require.Equal(t, 0, pc.FullProofsByNonceSize()) + require.Equal(t, 0, pc.ProofsByHashSize()) + }) +} + +func TestProofsCache_Concurrency(t *testing.T) { + t.Parallel() + + pc := proofscache.NewProofsCache(100) + + numOperations := 1000 + + wg := sync.WaitGroup{} + wg.Add(numOperations) + + for i := 0; i < numOperations; i++ { + go func(idx int) { + switch idx % 2 { + case 0: + pc.AddProof(generateProof()) + case 1: + pc.CleanupProofsBehindNonce(generateRandomNonce(100)) + default: + assert.Fail(t, "should have not beed called") + } + + wg.Done() + }(i) + } + + wg.Wait() +} + +func generateShuffledNonces(n int) []uint64 { + nonces := make([]uint64, n) + for i := uint64(0); i < uint64(n); i++ { + nonces[i] = i + } + + rand.Shuffle(len(nonces), func(i, j int) { + nonces[i], nonces[j] = nonces[j], nonces[i] + }) + + return nonces +} diff --git a/dataRetriever/dataPool/proofsCache/proofsPool.go b/dataRetriever/dataPool/proofsCache/proofsPool.go new file mode 100644 index 00000000000..d5c4c4a7629 --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/proofsPool.go @@ -0,0 +1,230 @@ +package proofscache + +import ( + "bytes" + "fmt" + "sync" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + logger "github.com/multiversx/mx-chain-logger-go" +) + +const defaultCleanupNonceDelta = 3 +const defaultBucketSize = 100 + +var log = logger.GetOrCreate("dataRetriever/proofscache") + +type proofsPool struct { + mutCache sync.RWMutex + cache map[uint32]*proofsCache + + mutAddedProofSubscribers sync.RWMutex + addedProofSubscribers []func(headerProof data.HeaderProofHandler) + cleanupNonceDelta uint64 + bucketSize int +} + +// NewProofsPool creates a new proofs pool component 
+func NewProofsPool(cleanupNonceDelta uint64, bucketSize int) *proofsPool { + if cleanupNonceDelta < defaultCleanupNonceDelta { + log.Debug("proofs pool: using default cleanup nonce delta", "cleanupNonceDelta", defaultCleanupNonceDelta) + cleanupNonceDelta = defaultCleanupNonceDelta + } + if bucketSize < defaultBucketSize { + log.Debug("proofs pool: using default bucket size", "bucketSize", defaultBucketSize) + bucketSize = defaultBucketSize + } + + return &proofsPool{ + cache: make(map[uint32]*proofsCache), + addedProofSubscribers: make([]func(headerProof data.HeaderProofHandler), 0), + cleanupNonceDelta: cleanupNonceDelta, + bucketSize: bucketSize, + } +} + +// UpsertProof will add the provided proof to the pool. If there is already an existing proof, +// it will overwrite it. +func (pp *proofsPool) UpsertProof( + headerProof data.HeaderProofHandler, +) bool { + if check.IfNil(headerProof) { + return false + } + + return pp.addProof(headerProof) +} + +// AddProof will add the provided proof to the pool, if it's not already in the pool. +// It will return true if the proof was added to the pool. +func (pp *proofsPool) AddProof( + headerProof data.HeaderProofHandler, +) bool { + if check.IfNil(headerProof) { + return false + } + + hasProof := pp.HasProof(headerProof.GetHeaderShardId(), headerProof.GetHeaderHash()) + if hasProof { + return false + } + + return pp.addProof(headerProof) +} + +func (pp *proofsPool) addProof( + headerProof data.HeaderProofHandler, +) bool { + shardID := headerProof.GetHeaderShardId() + + pp.mutCache.Lock() + proofsPerShard, ok := pp.cache[shardID] + if !ok { + proofsPerShard = newProofsCache(pp.bucketSize) + pp.cache[shardID] = proofsPerShard + } + pp.mutCache.Unlock() + + log.Debug("added proof to pool", + "header hash", headerProof.GetHeaderHash(), + "epoch", headerProof.GetHeaderEpoch(), + "nonce", headerProof.GetHeaderNonce(), + "shardID", headerProof.GetHeaderShardId(), + "pubKeys bitmap", headerProof.GetPubKeysBitmap(), + "round", headerProof.GetHeaderRound(), + "nonce", headerProof.GetHeaderNonce(), + "isStartOfEpoch", headerProof.GetIsStartOfEpoch(), + ) + + proofsPerShard.addProof(headerProof) + + pp.callAddedProofSubscribers(headerProof) + + return true +} + +// IsProofInPoolEqualTo will check if the provided proof is equal with the already existing proof in the pool +func (pp *proofsPool) IsProofInPoolEqualTo(headerProof data.HeaderProofHandler) bool { + if check.IfNil(headerProof) { + return false + } + + existingProof, err := pp.GetProof(headerProof.GetHeaderShardId(), headerProof.GetHeaderHash()) + if err != nil { + return false + } + + if !bytes.Equal(existingProof.GetAggregatedSignature(), headerProof.GetAggregatedSignature()) { + return false + } + if !bytes.Equal(existingProof.GetPubKeysBitmap(), headerProof.GetPubKeysBitmap()) { + return false + } + + return true +} + +func (pp *proofsPool) callAddedProofSubscribers(headerProof data.HeaderProofHandler) { + pp.mutAddedProofSubscribers.RLock() + defer pp.mutAddedProofSubscribers.RUnlock() + + for _, handler := range pp.addedProofSubscribers { + go handler(headerProof) + } +} + +// CleanupProofsBehindNonce will cleanup proofs from pool based on nonce +func (pp *proofsPool) CleanupProofsBehindNonce(shardID uint32, nonce uint64) error { + if nonce == 0 { + return nil + } + + if nonce <= pp.cleanupNonceDelta { + return nil + } + + nonce -= pp.cleanupNonceDelta + + pp.mutCache.RLock() + proofsPerShard, ok := pp.cache[shardID] + pp.mutCache.RUnlock() + if !ok { + return fmt.Errorf("%w: proofs cache 
per shard not found, shard ID: %d", ErrMissingProof, shardID) + } + + log.Trace("cleanup proofs behind nonce", + "nonce", nonce, + "shardID", shardID, + ) + + proofsPerShard.cleanupProofsBehindNonce(nonce) + + return nil +} + +// GetProof will get the proof from pool +func (pp *proofsPool) GetProof( + shardID uint32, + headerHash []byte, +) (data.HeaderProofHandler, error) { + if headerHash == nil { + return nil, fmt.Errorf("nil header hash") + } + log.Trace("trying to get proof", + "headerHash", headerHash, + "shardID", shardID, + ) + + pp.mutCache.RLock() + proofsPerShard, ok := pp.cache[shardID] + pp.mutCache.RUnlock() + if !ok { + return nil, fmt.Errorf("%w: proofs cache per shard not found, shard ID: %d", ErrMissingProof, shardID) + } + + return proofsPerShard.getProofByHash(headerHash) +} + +// GetProofByNonce will get the proof from pool for the provided header nonce, searching through all shards +func (pp *proofsPool) GetProofByNonce(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + log.Trace("trying to get proof", + "headerNonce", headerNonce, + "shardID", shardID, + ) + + pp.mutCache.RLock() + proofsPerShard, ok := pp.cache[shardID] + pp.mutCache.RUnlock() + if !ok { + return nil, fmt.Errorf("%w: proofs cache per shard not found, shard ID: %d", ErrMissingProof, shardID) + } + + return proofsPerShard.getProofByNonce(headerNonce) +} + +// HasProof will check if there is a proof for the provided hash +func (pp *proofsPool) HasProof( + shardID uint32, + headerHash []byte, +) bool { + _, err := pp.GetProof(shardID, headerHash) + return err == nil +} + +// RegisterHandler registers a new handler to be called when a new data is added +func (pp *proofsPool) RegisterHandler(handler func(headerProof data.HeaderProofHandler)) { + if handler == nil { + log.Error("attempt to register a nil handler to proofs pool") + return + } + + pp.mutAddedProofSubscribers.Lock() + pp.addedProofSubscribers = append(pp.addedProofSubscribers, handler) + pp.mutAddedProofSubscribers.Unlock() +} + +// IsInterfaceNil returns true if there is no value under the interface +func (pp *proofsPool) IsInterfaceNil() bool { + return pp == nil +} diff --git a/dataRetriever/dataPool/proofsCache/proofsPool_test.go b/dataRetriever/dataPool/proofsCache/proofsPool_test.go new file mode 100644 index 00000000000..efec39a85d6 --- /dev/null +++ b/dataRetriever/dataPool/proofsCache/proofsPool_test.go @@ -0,0 +1,336 @@ +package proofscache_test + +import ( + "crypto/rand" + "errors" + "math/big" + "sync" + "sync/atomic" + "testing" + + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const cleanupDelta = 3 +const bucketSize = 100 + +var shardID = uint32(1) + +var proof1 = &block.HeaderProof{ + PubKeysBitmap: []byte("pubKeysBitmap1"), + AggregatedSignature: []byte("aggSig1"), + HeaderHash: []byte("hash1"), + HeaderEpoch: 1, + HeaderNonce: 1, + HeaderShardId: shardID, +} + +var proof2 = &block.HeaderProof{ + PubKeysBitmap: []byte("pubKeysBitmap2"), + AggregatedSignature: []byte("aggSig2"), + HeaderHash: []byte("hash2"), + HeaderEpoch: 1, + HeaderNonce: 2, + HeaderShardId: shardID, +} +var proof3 = &block.HeaderProof{ + PubKeysBitmap: []byte("pubKeysBitmap3"), + AggregatedSignature: []byte("aggSig3"), + 
HeaderHash: []byte("hash3"), + HeaderEpoch: 1, + HeaderNonce: 3, + HeaderShardId: shardID, +} +var proof4 = &block.HeaderProof{ + PubKeysBitmap: []byte("pubKeysBitmap4"), + AggregatedSignature: []byte("aggSig4"), + HeaderHash: []byte("hash4"), + HeaderEpoch: 1, + HeaderNonce: 4, + HeaderShardId: shardID, +} + +func TestNewProofsPool(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + require.False(t, pp.IsInterfaceNil()) +} + +func TestProofsPool_ShouldWork(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + + ok := pp.AddProof(nil) + require.False(t, ok) + + _ = pp.AddProof(proof1) + _ = pp.AddProof(proof2) + _ = pp.AddProof(proof3) + _ = pp.AddProof(proof4) + + ok = pp.AddProof(proof4) + require.False(t, ok) + + proof, err := pp.GetProof(shardID, []byte("hash3")) + require.Nil(t, err) + require.Equal(t, proof3, proof) + proof, err = pp.GetProofByNonce(3, shardID) + require.Nil(t, err) + require.Equal(t, proof3, proof) + + err = pp.CleanupProofsBehindNonce(shardID, 4) + require.Nil(t, err) + + proof, err = pp.GetProof(shardID, []byte("hash3")) + require.Nil(t, err) + require.Equal(t, proof3, proof) + proof, err = pp.GetProofByNonce(3, shardID) + require.Nil(t, err) + require.Equal(t, proof3, proof) + + proof, err = pp.GetProof(shardID, []byte("hash4")) + require.Nil(t, err) + require.Equal(t, proof4, proof) + proof, err = pp.GetProofByNonce(4, shardID) + require.Nil(t, err) + require.Equal(t, proof4, proof) +} + +func TestProofsPool_Upsert(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + + ok := pp.UpsertProof(nil) + require.False(t, ok) + + ok = pp.UpsertProof(proof1) + require.True(t, ok) + + proof, err := pp.GetProof(shardID, []byte("hash1")) + require.Nil(t, err) + require.NotNil(t, proof) + + require.Equal(t, proof1.GetAggregatedSignature(), proof.GetAggregatedSignature()) + require.Equal(t, proof1.GetPubKeysBitmap(), proof.GetPubKeysBitmap()) + + newProof1 := &block.HeaderProof{ + PubKeysBitmap: []byte("newpubKeysBitmap1"), + AggregatedSignature: []byte("newaggSig1"), + HeaderHash: []byte("hash1"), + HeaderEpoch: 1, + HeaderNonce: 1, + HeaderShardId: shardID, + } + + ok = pp.UpsertProof(newProof1) + require.True(t, ok) + + proof, err = pp.GetProof(shardID, []byte("hash1")) + require.Nil(t, err) + require.NotNil(t, proof) + + require.Equal(t, newProof1.GetAggregatedSignature(), proof.GetAggregatedSignature()) + require.Equal(t, newProof1.GetPubKeysBitmap(), proof.GetPubKeysBitmap()) +} + +func TestProofsPool_IsProofEqual(t *testing.T) { + t.Parallel() + + t.Run("not existing proof, should fail", func(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + + ok := pp.IsProofInPoolEqualTo(proof1) + require.False(t, ok) + }) + + t.Run("nil provided proof, should fail", func(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + + ok := pp.IsProofInPoolEqualTo(nil) + require.False(t, ok) + }) + + t.Run("same proof, should return true", func(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + + ok := pp.UpsertProof(proof1) + require.True(t, ok) + + ok = pp.IsProofInPoolEqualTo(proof1) + require.True(t, ok) + }) + + t.Run("not equal, should return false", func(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + + ok := pp.UpsertProof(proof1) + require.True(t, ok) + + newProof1 := 
&block.HeaderProof{ + PubKeysBitmap: []byte("newpubKeysBitmap1"), + AggregatedSignature: []byte("newaggSig1"), + HeaderHash: []byte("hash1"), + HeaderEpoch: 1, + HeaderNonce: 1, + HeaderShardId: shardID, + } + + ok = pp.IsProofInPoolEqualTo(newProof1) + require.False(t, ok) + }) +} + +func TestProofsPool_RegisterHandler(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + + wasCalled := false + wg := sync.WaitGroup{} + wg.Add(1) + handler := func(proof data.HeaderProofHandler) { + wasCalled = true + wg.Done() + } + pp.RegisterHandler(nil) + pp.RegisterHandler(handler) + + _ = pp.AddProof(generateProof()) + + wg.Wait() + + assert.True(t, wasCalled) +} + +func TestProofsPool_CleanupProofsBehindNonce(t *testing.T) { + t.Parallel() + + t.Run("should not cleanup proofs behind delta", func(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + + _ = pp.AddProof(proof1) + _ = pp.AddProof(proof2) + _ = pp.AddProof(proof3) + _ = pp.AddProof(proof4) + + _, err := pp.GetProof(shardID, []byte("hash2")) + require.Nil(t, err) + _, err = pp.GetProof(shardID, []byte("hash3")) + require.Nil(t, err) + _, err = pp.GetProof(shardID, []byte("hash4")) + require.Nil(t, err) + }) + + t.Run("should not cleanup if nonce smaller or equal to delta", func(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + + _ = pp.AddProof(proof1) + _ = pp.AddProof(proof2) + _ = pp.AddProof(proof3) + _ = pp.AddProof(proof4) + + err := pp.CleanupProofsBehindNonce(shardID, cleanupDelta) + require.Nil(t, err) + + _, err = pp.GetProof(shardID, []byte("hash1")) + require.Nil(t, err) + _, err = pp.GetProof(shardID, []byte("hash2")) + require.Nil(t, err) + _, err = pp.GetProof(shardID, []byte("hash3")) + require.Nil(t, err) + _, err = pp.GetProof(shardID, []byte("hash4")) + require.Nil(t, err) + }) +} + +func TestProofsPool_Concurrency(t *testing.T) { + t.Parallel() + + pp := proofscache.NewProofsPool(cleanupDelta, bucketSize) + + numOperations := 1000 + + wg := sync.WaitGroup{} + wg.Add(numOperations) + + cnt := uint32(0) + + for i := 0; i < numOperations; i++ { + go func(idx int) { + switch idx % 7 { + case 0, 1, 2: + _ = pp.AddProof(generateProof()) + case 3: + _, err := pp.GetProof(generateRandomShardID(), generateRandomHash()) + if errors.Is(err, proofscache.ErrMissingProof) { + atomic.AddUint32(&cnt, 1) + } + case 4: + _, _ = pp.GetProofByNonce(generateRandomNonce(100), generateRandomShardID()) + case 5: + _ = pp.CleanupProofsBehindNonce(generateRandomShardID(), generateRandomNonce(100)) + case 6: + handler := func(proof data.HeaderProofHandler) { + } + pp.RegisterHandler(handler) + default: + assert.Fail(t, "should have not beed called") + } + + wg.Done() + }(i) + } + + wg.Wait() + + require.GreaterOrEqual(t, uint32(numOperations/3), atomic.LoadUint32(&cnt)) +} + +func generateProof() *block.HeaderProof { + return &block.HeaderProof{ + HeaderHash: generateRandomHash(), + HeaderEpoch: 1, + HeaderNonce: generateRandomNonce(100), + HeaderShardId: generateRandomShardID(), + } +} + +func generateRandomHash() []byte { + hashSuffix := generateRandomInt(100) + hash := []byte("hash_" + hashSuffix.String()) + return hash +} + +func generateRandomNonce(n int64) uint64 { + val := generateRandomInt(n) + return val.Uint64() +} + +func generateRandomShardID() uint32 { + val := generateRandomInt(3) + return uint32(val.Uint64()) +} + +func generateRandomInt(max int64) *big.Int { + rantInt, _ := rand.Int(rand.Reader, 
big.NewInt(max)) + return rantInt +} diff --git a/dataRetriever/errors.go b/dataRetriever/errors.go index 8b7b2f2e3dc..f7d053567ec 100644 --- a/dataRetriever/errors.go +++ b/dataRetriever/errors.go @@ -262,3 +262,12 @@ var ErrNilValidatorInfoStorage = errors.New("nil validator info storage") // ErrValidatorInfoNotFound signals that no validator info was found var ErrValidatorInfoNotFound = errors.New("validator info not found") + +// ErrNilProofsPool signals that a nil proofs pool has been provided +var ErrNilProofsPool = errors.New("nil proofs pool") + +// ErrEquivalentProofsNotFound signals that no equivalent proof found +var ErrEquivalentProofsNotFound = errors.New("equivalent proof not found") + +// ErrNilEnableEpochsHandler signals that a nil enable epochs handler has been provided +var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") diff --git a/dataRetriever/factory/dataPoolFactory.go b/dataRetriever/factory/dataPoolFactory.go index f48dc3d3c37..e375ed2c785 100644 --- a/dataRetriever/factory/dataPoolFactory.go +++ b/dataRetriever/factory/dataPoolFactory.go @@ -11,6 +11,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/headersCache" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/dataRetriever/shardedData" "github.com/multiversx/mx-chain-go/dataRetriever/txpool" "github.com/multiversx/mx-chain-go/process" @@ -150,8 +151,10 @@ func NewDataPoolFromConfig(args ArgsDataPool) (dataRetriever.PoolsHolder, error) return nil, fmt.Errorf("%w while creating the cache for the validator info results", err) } + proofsPool := proofscache.NewProofsPool(mainConfig.ProofsPoolConfig.CleanupNonceDelta, mainConfig.ProofsPoolConfig.BucketSize) currBlockTransactions := dataPool.NewCurrentBlockTransactionsPool() currEpochValidatorInfo := dataPool.NewCurrentEpochValidatorInfoPool() + dataPoolArgs := dataPool.DataPoolArgs{ Transactions: txPool, UnsignedTransactions: uTxPool, @@ -167,6 +170,7 @@ func NewDataPoolFromConfig(args ArgsDataPool) (dataRetriever.PoolsHolder, error) PeerAuthentications: peerAuthPool, Heartbeats: heartbeatPool, ValidatorsInfo: validatorsInfo, + Proofs: proofsPool, } return dataPool.NewDataPool(dataPoolArgs) } diff --git a/dataRetriever/factory/requestersContainer/args.go b/dataRetriever/factory/requestersContainer/args.go index 96f09453cb9..61d99a35c6d 100644 --- a/dataRetriever/factory/requestersContainer/args.go +++ b/dataRetriever/factory/requestersContainer/args.go @@ -3,6 +3,7 @@ package requesterscontainer import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" @@ -23,4 +24,5 @@ type FactoryArgs struct { FullArchivePreferredPeersHolder p2p.PreferredPeersHolderHandler PeersRatingHandler dataRetriever.PeersRatingHandler SizeCheckDelta uint32 + EnableEpochsHandler common.EnableEpochsHandler } diff --git a/dataRetriever/factory/requestersContainer/baseRequestersContainerFactory.go 
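The proofsPool above keeps one proofsCache per shard and is wired into the node through NewDataPoolFromConfig, which builds it from ProofsPoolConfig.CleanupNonceDelta and ProofsPoolConfig.BucketSize (both floored to the package defaults of 3 and 100) and hands it to DataPoolArgs.Proofs, reachable via the new Proofs() accessor. Semantics visible in this diff: AddProof inserts only if the hash is not already present, UpsertProof always overwrites, CleanupProofsBehindNonce keeps a trailing window of cleanupNonceDelta nonces, and handlers registered with RegisterHandler are invoked asynchronously on every insert. Below is a usage sketch built only from those methods; the concrete header values (shard 1, nonce 7) are illustrative.

```go
package main

import (
	"fmt"

	"github.com/multiversx/mx-chain-core-go/data"
	"github.com/multiversx/mx-chain-core-go/data/block"
	proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache"
)

func main() {
	// arguments below the package defaults (3 and 100) would be floored to them
	pp := proofscache.NewProofsPool(3, 100)

	// subscribers are notified on a separate goroutine for every inserted proof
	pp.RegisterHandler(func(proof data.HeaderProofHandler) {
		fmt.Println("received proof for nonce", proof.GetHeaderNonce())
	})

	proof := &block.HeaderProof{
		HeaderHash:          []byte("header hash"),
		HeaderNonce:         7,
		HeaderShardId:       1,
		HeaderEpoch:         1,
		PubKeysBitmap:       []byte{0xff},
		AggregatedSignature: []byte("aggregated signature"),
	}

	fmt.Println(pp.AddProof(proof))    // true: first insert for this hash
	fmt.Println(pp.AddProof(proof))    // false: AddProof skips already known hashes
	fmt.Println(pp.UpsertProof(proof)) // true: UpsertProof always overwrites

	retrieved, err := pp.GetProof(1, []byte("header hash"))
	fmt.Println(err == nil, retrieved.GetHeaderNonce()) // true 7

	byNonce, err := pp.GetProofByNonce(7, 1)
	fmt.Println(err == nil, string(byNonce.GetHeaderHash())) // true "header hash"

	// removes proofs older than nonce-cleanupNonceDelta for the given shard
	_ = pp.CleanupProofsBehindNonce(1, 20)
	fmt.Println(pp.HasProof(1, []byte("header hash"))) // false: nonce 7 is behind 20-3
}
```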
b/dataRetriever/factory/requestersContainer/baseRequestersContainerFactory.go index 2ec10054d8d..4df84e13279 100644 --- a/dataRetriever/factory/requestersContainer/baseRequestersContainerFactory.go +++ b/dataRetriever/factory/requestersContainer/baseRequestersContainerFactory.go @@ -36,6 +36,7 @@ type baseRequestersContainerFactory struct { mainPreferredPeersHolder dataRetriever.PreferredPeersHolderHandler fullArchivePreferredPeersHolder dataRetriever.PreferredPeersHolderHandler peersRatingHandler dataRetriever.PeersRatingHandler + enableEpochsHandler common.EnableEpochsHandler numCrossShardPeers int numIntraShardPeers int numTotalPeers int @@ -344,3 +345,30 @@ func (brcf *baseRequestersContainerFactory) generateValidatorInfoRequester() err return brcf.container.Add(identifierValidatorInfo, requester) } + +func (brcf *baseRequestersContainerFactory) createEquivalentProofsRequester( + topic string, + numCrossShardPeers int, + numIntraShardPeers int, + targetShardID uint32, +) (dataRetriever.Requester, error) { + requestSender, err := brcf.createOneRequestSenderWithSpecifiedNumRequests( + topic, + EmptyExcludePeersOnTopic, + targetShardID, + numCrossShardPeers, + numIntraShardPeers, + ) + if err != nil { + return nil, err + } + + arg := requesters.ArgEquivalentProofsRequester{ + ArgBaseRequester: requesters.ArgBaseRequester{ + RequestSender: requestSender, + Marshaller: brcf.marshaller, + }, + EnableEpochsHandler: brcf.enableEpochsHandler, + } + return requesters.NewEquivalentProofsRequester(arg) +} diff --git a/dataRetriever/factory/requestersContainer/metaRequestersContainerFactory.go b/dataRetriever/factory/requestersContainer/metaRequestersContainerFactory.go index c718f5b22a1..06f0c1e3ec8 100644 --- a/dataRetriever/factory/requestersContainer/metaRequestersContainerFactory.go +++ b/dataRetriever/factory/requestersContainer/metaRequestersContainerFactory.go @@ -43,6 +43,7 @@ func NewMetaRequestersContainerFactory( numIntraShardPeers: int(numIntraShardPeers), numTotalPeers: int(args.RequesterConfig.NumTotalPeers), numFullHistoryPeers: int(args.RequesterConfig.NumFullHistoryPeers), + enableEpochsHandler: args.EnableEpochsHandler, } err := base.checkParams() @@ -85,6 +86,11 @@ func (mrcf *metaRequestersContainerFactory) Create() (dataRetriever.RequestersCo return nil, err } + err = mrcf.generateEquivalentProofsRequesters() + if err != nil { + return nil, err + } + return mrcf.container, nil } @@ -258,6 +264,47 @@ func (mrcf *metaRequestersContainerFactory) generateRewardsRequesters(topic stri return mrcf.container.AddMultiple(keys, requestersSlice) } +func (mrcf *metaRequestersContainerFactory) generateEquivalentProofsRequesters() error { + shardC := mrcf.shardCoordinator + noOfShards := shardC.NumberOfShards() + + keys := make([]string, 0) + requestersSlice := make([]dataRetriever.Requester, 0) + + // on meta should be one requester for each shard + one for ALL, similar as interceptors + for idx := uint32(0); idx < noOfShards; idx++ { + identifier := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(idx) + requester, err := mrcf.createEquivalentProofsRequester( + identifier, + mrcf.numCrossShardPeers, + mrcf.numIntraShardPeers, + idx, + ) + if err != nil { + return err + } + + requestersSlice = append(requestersSlice, requester) + keys = append(keys, identifier) + } + + identifier := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + requester, err := mrcf.createEquivalentProofsRequester( + identifier, + 
mrcf.numCrossShardPeers, + mrcf.numIntraShardPeers, + core.MetachainShardId, + ) + if err != nil { + return err + } + + requestersSlice = append(requestersSlice, requester) + keys = append(keys, identifier) + + return mrcf.container.AddMultiple(keys, requestersSlice) +} + // IsInterfaceNil returns true if there is no value under the interface func (mrcf *metaRequestersContainerFactory) IsInterfaceNil() bool { return mrcf == nil diff --git a/dataRetriever/factory/requestersContainer/metaRequestersContainerFactory_test.go b/dataRetriever/factory/requestersContainer/metaRequestersContainerFactory_test.go index e68f4c7e5a5..dca3b6426a3 100644 --- a/dataRetriever/factory/requestersContainer/metaRequestersContainerFactory_test.go +++ b/dataRetriever/factory/requestersContainer/metaRequestersContainerFactory_test.go @@ -223,8 +223,10 @@ func TestMetaRequestersContainerFactory_With4ShardsShouldWork(t *testing.T) { numRequestersTrieNodes := 2 numRequestersPeerAuth := 1 numRequesterValidatorInfo := 1 + numRequesterEquivalentProofs := noOfShards + 1 totalRequesters := numRequestersShardHeadersForMetachain + numRequesterMetablocks + numRequestersMiniBlocks + - numRequestersUnsigned + numRequestersTxs + numRequestersTrieNodes + numRequestersRewards + numRequestersPeerAuth + numRequesterValidatorInfo + numRequestersUnsigned + numRequestersTxs + numRequestersTrieNodes + numRequestersRewards + numRequestersPeerAuth + + numRequesterValidatorInfo + numRequesterEquivalentProofs assert.Equal(t, totalRequesters, container.Len()) diff --git a/dataRetriever/factory/requestersContainer/shardRequestersContainerFactory.go b/dataRetriever/factory/requestersContainer/shardRequestersContainerFactory.go index d7468d5302d..014cef057c3 100644 --- a/dataRetriever/factory/requestersContainer/shardRequestersContainerFactory.go +++ b/dataRetriever/factory/requestersContainer/shardRequestersContainerFactory.go @@ -42,6 +42,7 @@ func NewShardRequestersContainerFactory( numIntraShardPeers: int(numIntraShardPeers), numTotalPeers: int(args.RequesterConfig.NumTotalPeers), numFullHistoryPeers: int(args.RequesterConfig.NumFullHistoryPeers), + enableEpochsHandler: args.EnableEpochsHandler, } err := base.checkParams() @@ -84,6 +85,11 @@ func (srcf *shardRequestersContainerFactory) Create() (dataRetriever.RequestersC return nil, err } + err = srcf.generateEquivalentProofsRequesters() + if err != nil { + return nil, err + } + return srcf.container, nil } @@ -179,6 +185,44 @@ func (srcf *shardRequestersContainerFactory) generateRewardRequester(topic strin return srcf.container.AddMultiple(keys, requestersSlice) } +func (srcf *shardRequestersContainerFactory) generateEquivalentProofsRequesters() error { + shardC := srcf.shardCoordinator + + keys := make([]string, 0) + requestersSlice := make([]dataRetriever.Requester, 0) + + // should be 2 resolvers on shards, similar as interceptors: self_META + ALL + identifier := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(core.MetachainShardId) + requester, err := srcf.createEquivalentProofsRequester( + identifier, + 0, + srcf.numTotalPeers, + shardC.SelfId(), + ) + if err != nil { + return err + } + + requestersSlice = append(requestersSlice, requester) + keys = append(keys, identifier) + + identifier = common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + requester, err = srcf.createEquivalentProofsRequester( + identifier, + srcf.numCrossShardPeers, + srcf.numIntraShardPeers, + core.MetachainShardId, + ) + if err != nil { 
+ return err + } + + requestersSlice = append(requestersSlice, requester) + keys = append(keys, identifier) + + return srcf.container.AddMultiple(keys, requestersSlice) +} + // IsInterfaceNil returns true if there is no value under the interface func (srcf *shardRequestersContainerFactory) IsInterfaceNil() bool { return srcf == nil diff --git a/dataRetriever/factory/requestersContainer/shardRequestersContainerFactory_test.go b/dataRetriever/factory/requestersContainer/shardRequestersContainerFactory_test.go index e4c94491487..27424678169 100644 --- a/dataRetriever/factory/requestersContainer/shardRequestersContainerFactory_test.go +++ b/dataRetriever/factory/requestersContainer/shardRequestersContainerFactory_test.go @@ -5,11 +5,13 @@ import ( "strings" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/factory/requestersContainer" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -253,8 +255,10 @@ func TestShardRequestersContainerFactory_With4ShardsShouldWork(t *testing.T) { numRequesterTrieNodes := 1 numRequesterPeerAuth := 1 numRequesterValidatorInfo := 1 + numRequesterEquivalentProofs := 2 totalRequesters := numRequesterTxs + numRequesterHeaders + numRequesterMiniBlocks + numRequesterMetaBlockHeaders + - numRequesterSCRs + numRequesterRewardTxs + numRequesterTrieNodes + numRequesterPeerAuth + numRequesterValidatorInfo + numRequesterSCRs + numRequesterRewardTxs + numRequesterTrieNodes + numRequesterPeerAuth + numRequesterValidatorInfo + + numRequesterEquivalentProofs assert.Equal(t, totalRequesters, container.Len()) } @@ -277,5 +281,10 @@ func getArguments() requesterscontainer.FactoryArgs { FullArchivePreferredPeersHolder: &p2pmocks.PeersHolderStub{}, PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, SizeCheckDelta: 0, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + }, } } diff --git a/dataRetriever/factory/resolverscontainer/baseResolversContainerFactory.go b/dataRetriever/factory/resolverscontainer/baseResolversContainerFactory.go index 3d0eff8eaa9..ae5216113e2 100644 --- a/dataRetriever/factory/resolverscontainer/baseResolversContainerFactory.go +++ b/dataRetriever/factory/resolverscontainer/baseResolversContainerFactory.go @@ -417,3 +417,47 @@ func (brcf *baseResolversContainerFactory) generateValidatorInfoResolver() error return brcf.container.Add(identifierValidatorInfo, validatorInfoResolver) } + +func (brcf *baseResolversContainerFactory) createEquivalentProofsResolver( + topic string, + targetShardID uint32, +) (dataRetriever.Resolver, error) { + resolverSender, err := brcf.createOneResolverSenderWithSpecifiedNumRequests( + topic, + EmptyExcludePeersOnTopic, + targetShardID, + ) + if err != nil { + return nil, err + } + + arg := resolvers.ArgEquivalentProofsResolver{ + ArgBaseResolver: resolvers.ArgBaseResolver{ + SenderResolver: resolverSender, + Marshaller: brcf.marshalizer, + 
AntifloodHandler: brcf.inputAntifloodHandler, + Throttler: brcf.trieNodesThrottler, + }, + DataPacker: brcf.dataPacker, + Storage: brcf.store, + EquivalentProofsPool: brcf.dataPools.Proofs(), + NonceConverter: brcf.uint64ByteSliceConverter, + IsFullHistoryNode: brcf.isFullHistoryNode, + } + resolver, err := resolvers.NewEquivalentProofsResolver(arg) + if err != nil { + return nil, err + } + + err = brcf.mainMessenger.RegisterMessageProcessor(resolver.RequestTopic(), common.DefaultResolversIdentifier, resolver) + if err != nil { + return nil, err + } + + err = brcf.fullArchiveMessenger.RegisterMessageProcessor(resolver.RequestTopic(), common.DefaultResolversIdentifier, resolver) + if err != nil { + return nil, err + } + + return resolver, nil +} diff --git a/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory.go b/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory.go index b72f8c3154a..a84deb54aa4 100644 --- a/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory.go +++ b/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory.go @@ -131,6 +131,11 @@ func (mrcf *metaResolversContainerFactory) Create() (dataRetriever.ResolversCont return nil, err } + err = mrcf.generateEquivalentProofsResolvers() + if err != nil { + return nil, err + } + return mrcf.container, nil } @@ -204,7 +209,7 @@ func (mrcf *metaResolversContainerFactory) createShardHeaderResolver( } // TODO change this data unit creation method through a factory or func - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(shardID) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(shardID) hdrNonceStore, err := mrcf.store.GetStorer(hdrNonceHashDataUnit) if err != nil { return nil, err @@ -367,6 +372,43 @@ func (mrcf *metaResolversContainerFactory) generateRewardsResolvers( return mrcf.container.AddMultiple(keys, resolverSlice) } +func (mrcf *metaResolversContainerFactory) generateEquivalentProofsResolvers() error { + shardC := mrcf.shardCoordinator + noOfShards := shardC.NumberOfShards() + + keys := make([]string, 0) + resolversSlice := make([]dataRetriever.Resolver, 0) + + // on meta should be one resolver for each shard + one for ALL, similar as interceptors + for idx := uint32(0); idx < noOfShards; idx++ { + identifier := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(idx) + resolver, err := mrcf.createEquivalentProofsResolver( + identifier, + idx, + ) + if err != nil { + return err + } + + resolversSlice = append(resolversSlice, resolver) + keys = append(keys, identifier) + } + + identifier := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + resolver, err := mrcf.createEquivalentProofsResolver( + identifier, + core.MetachainShardId, + ) + if err != nil { + return err + } + + resolversSlice = append(resolversSlice, resolver) + keys = append(keys, identifier) + + return mrcf.container.AddMultiple(keys, resolversSlice) +} + // IsInterfaceNil returns true if there is no value under the interface func (mrcf *metaResolversContainerFactory) IsInterfaceNil() bool { return mrcf == nil diff --git a/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go b/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go index 755672384cd..533d682914b 100644 --- a/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go +++ 
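The requester and resolver container factories (network and storage variants alike) wire the new common.EquivalentProofsTopic with the same topology: metachain nodes create one component per shard, with identifiers built from shardCoordinator.CommunicationIdentifier(shardID), plus one on the META-to-ALL identifier, while shard nodes create only two, on their own shard-to-META identifier and on the META-to-ALL identifier. That is why the container-size tests grow by noOfShards + 1 on meta and by 2 on shards. The sketch below only illustrates this enumeration; the topic prefix and the identifier format are stand-ins, the real values coming from common.EquivalentProofsTopic and the shard coordinator.

```go
package main

import "fmt"

// equivalentProofsTopics sketches the topic set built by the factories above.
// identifierFor is a stand-in for shardCoordinator.CommunicationIdentifier and
// core.CommunicationIdentifierBetweenShards; "equivalentProofs" is only a
// placeholder for common.EquivalentProofsTopic.
func equivalentProofsTopics(isMetachain bool, numShards uint32, identifierFor func(from, to string) string) []string {
	const topicPrefix = "equivalentProofs"
	topics := make([]string, 0)

	if isMetachain {
		// metachain: one topic towards every shard...
		for shard := uint32(0); shard < numShards; shard++ {
			topics = append(topics, topicPrefix+identifierFor("META", fmt.Sprint(shard)))
		}
	} else {
		// ...a shard node only joins its own shard-to-META topic
		topics = append(topics, topicPrefix+identifierFor("SELF", "META"))
	}

	// every node additionally joins the META <-> ALL identifier
	topics = append(topics, topicPrefix+identifierFor("META", "ALL"))

	return topics
}

func main() {
	ident := func(from, to string) string { return "_" + from + "_" + to }

	// metachain: noOfShards + 1 topics, matching the updated factory tests
	fmt.Println(equivalentProofsTopics(true, 3, ident))

	// shard node: exactly 2 topics, matching numRequesterEquivalentProofs := 2
	fmt.Println(equivalentProofsTopics(false, 3, ident))
}
```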
b/dataRetriever/factory/resolverscontainer/metaResolversContainerFactory_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/factory/resolverscontainer" @@ -15,11 +17,11 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" - "github.com/stretchr/testify/assert" ) func createStubMessengerForMeta(matchStrToErrOnCreate string, matchStrToErrOnRegister string) p2p.Messenger { @@ -56,7 +58,7 @@ func createDataPoolsForMeta() dataRetriever.PoolsHolder { return &mock.HeadersCacherStub{} }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() @@ -67,6 +69,9 @@ func createDataPoolsForMeta() dataRetriever.PoolsHolder { RewardTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } return pools @@ -338,8 +343,10 @@ func TestMetaResolversContainerFactory_With4ShardsShouldWork(t *testing.T) { numResolversTrieNodes := 2 numResolversPeerAuth := 1 numResolverValidatorInfo := 1 + numResolverEquivalentProofs := noOfShards + 1 totalResolvers := numResolversShardHeadersForMetachain + numResolverMetablocks + numResolversMiniBlocks + - numResolversUnsigned + numResolversTxs + numResolversTrieNodes + numResolversRewards + numResolversPeerAuth + numResolverValidatorInfo + numResolversUnsigned + numResolversTxs + numResolversTrieNodes + numResolversRewards + numResolversPeerAuth + + numResolverValidatorInfo + numResolverEquivalentProofs assert.Equal(t, totalResolvers, container.Len()) assert.Equal(t, totalResolvers, registerMainCnt) diff --git a/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory.go b/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory.go index f24beaa4331..3c1f374e4a8 100644 --- a/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory.go +++ b/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory.go @@ -129,6 +129,11 @@ func (srcf *shardResolversContainerFactory) Create() (dataRetriever.ResolversCon return nil, err } + err = srcf.generateEquivalentProofsResolvers() + if err != nil { + return nil, err + } + return srcf.container, nil } @@ -149,7 +154,7 @@ func (srcf *shardResolversContainerFactory) generateHeaderResolvers() error { return err } - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(shardC.SelfId()) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(shardC.SelfId()) hdrNonceStore, err 
:= srcf.store.GetStorer(hdrNonceHashDataUnit) if err != nil { return err @@ -286,6 +291,40 @@ func (srcf *shardResolversContainerFactory) generateRewardResolver( return srcf.container.AddMultiple(keys, resolverSlice) } +func (srcf *shardResolversContainerFactory) generateEquivalentProofsResolvers() error { + shardC := srcf.shardCoordinator + + keys := make([]string, 0) + resolversSlice := make([]dataRetriever.Resolver, 0) + + // should be 2 resolvers on shards, similar as interceptors: self_META + ALL + identifier := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(core.MetachainShardId) + resolver, err := srcf.createEquivalentProofsResolver( + identifier, + shardC.SelfId(), + ) + if err != nil { + return err + } + + resolversSlice = append(resolversSlice, resolver) + keys = append(keys, identifier) + + identifier = common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + resolver, err = srcf.createEquivalentProofsResolver( + identifier, + core.MetachainShardId, + ) + if err != nil { + return err + } + + resolversSlice = append(resolversSlice, resolver) + keys = append(keys, identifier) + + return srcf.container.AddMultiple(keys, resolversSlice) +} + // IsInterfaceNil returns true if there is no value under the interface func (srcf *shardResolversContainerFactory) IsInterfaceNil() bool { return srcf == nil diff --git a/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go b/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go index ca97015f3ae..1d9e8c5a678 100644 --- a/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go +++ b/dataRetriever/factory/resolverscontainer/shardResolversContainerFactory_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/factory/resolverscontainer" @@ -15,11 +17,11 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" - "github.com/stretchr/testify/assert" ) var errExpected = errors.New("expected error") @@ -63,10 +65,10 @@ func createDataPoolsForShard() dataRetriever.PoolsHolder { return &mock.HeadersCacherStub{} } pools.MiniBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.PeerChangesBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.UnsignedTransactionsCalled = func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() @@ -74,6 +76,9 @@ func createDataPoolsForShard() dataRetriever.PoolsHolder { pools.RewardTransactionsCalled = func() dataRetriever.ShardedDataCacherNotifier { return 
testscommon.NewShardedDataStub() } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } return pools } @@ -450,8 +455,10 @@ func TestShardResolversContainerFactory_With4ShardsShouldWork(t *testing.T) { numResolverTrieNodes := 1 numResolverPeerAuth := 1 numResolverValidatorInfo := 1 + numResolverEquivalentProofs := 2 totalResolvers := numResolverTxs + numResolverHeaders + numResolverMiniBlocks + numResolverMetaBlockHeaders + - numResolverSCRs + numResolverRewardTxs + numResolverTrieNodes + numResolverPeerAuth + numResolverValidatorInfo + numResolverSCRs + numResolverRewardTxs + numResolverTrieNodes + numResolverPeerAuth + numResolverValidatorInfo + + numResolverEquivalentProofs assert.Equal(t, totalResolvers, container.Len()) assert.Equal(t, totalResolvers, registerMainCnt) diff --git a/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go b/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go index 2682231a768..72e68d61c77 100644 --- a/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go +++ b/dataRetriever/factory/storageRequestersContainer/baseRequestersContainerFactory.go @@ -266,3 +266,21 @@ func (brcf *baseRequestersContainerFactory) generateValidatorInfoRequester() err return brcf.container.Add(identifierValidatorInfo, validatorInfoRequester) } + +func (brcf *baseRequestersContainerFactory) createEquivalentProofsRequester( + topic string, +) (dataRetriever.Requester, error) { + args := storagerequesters.ArgEquivalentProofsRequester{ + Messenger: brcf.messenger, + ResponseTopicName: topic, + ManualEpochStartNotifier: brcf.manualEpochStartNotifier, + ChanGracefullyClose: brcf.chanGracefullyClose, + DelayBeforeGracefulClose: defaultBeforeGracefulClose, + NonceConverter: brcf.uint64ByteSliceConverter, + Storage: brcf.store, + Marshaller: brcf.marshalizer, + EnableEpochsHandler: brcf.enableEpochsHandler, + } + + return storagerequesters.NewEquivalentProofsRequester(args) +} diff --git a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go index 9277a29a991..e430ff170dc 100644 --- a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go +++ b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory.go @@ -1,6 +1,8 @@ package storagerequesterscontainer import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/factory/containers" storagerequesters "github.com/multiversx/mx-chain-go/dataRetriever/storageRequesters" @@ -73,6 +75,11 @@ func (mrcf *metaRequestersContainerFactory) Create() (dataRetriever.RequestersCo return nil, err } + err = mrcf.generateEquivalentProofsRequesters() + if err != nil { + return nil, err + } + return mrcf.container, nil } @@ -107,7 +114,7 @@ func (mrcf *metaRequestersContainerFactory) createShardHeaderRequester( } // TODO change this data unit creation method through a factory or func - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(shardID) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(shardID) hdrNonceStore, err := mrcf.store.GetStorer(hdrNonceHashDataUnit) if err != nil { 
return nil, err @@ -196,6 +203,37 @@ func (mrcf *metaRequestersContainerFactory) generateRewardsRequesters( return mrcf.container.AddMultiple(keys, requestersSlice) } +func (mrcf *metaRequestersContainerFactory) generateEquivalentProofsRequesters() error { + shardC := mrcf.shardCoordinator + noOfShards := shardC.NumberOfShards() + + keys := make([]string, 0) + requestersSlice := make([]dataRetriever.Requester, 0) + + // on meta should be one requester for each shard + one for ALL, similar to interceptors + for idx := uint32(0); idx < noOfShards; idx++ { + identifier := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(idx) + requester, err := mrcf.createEquivalentProofsRequester(identifier) + if err != nil { + return err + } + + requestersSlice = append(requestersSlice, requester) + keys = append(keys, identifier) + } + + identifier := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + requester, err := mrcf.createEquivalentProofsRequester(identifier) + if err != nil { + return err + } + + requestersSlice = append(requestersSlice, requester) + keys = append(keys, identifier) + + return mrcf.container.AddMultiple(keys, requestersSlice) +} + // IsInterfaceNil returns true if there is no value under the interface func (mrcf *metaRequestersContainerFactory) IsInterfaceNil() bool { return mrcf == nil diff --git a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go index c166223ad20..6a44b37b153 100644 --- a/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go +++ b/dataRetriever/factory/storageRequestersContainer/metaRequestersContainerFactory_test.go @@ -181,9 +181,10 @@ func TestMetaRequestersContainerFactory_With4ShardsShouldWork(t *testing.T) { numRequestersTxs := noOfShards + 1 numPeerAuthentication := 1 numValidatorInfo := 1 + numEquivalentProofs := noOfShards + 1 totalRequesters := numRequestersShardHeadersForMetachain + numRequesterMetablocks + numRequestersMiniBlocks + numRequestersUnsigned + numRequestersTxs + numRequestersRewards + numPeerAuthentication + - numValidatorInfo + numValidatorInfo + numEquivalentProofs assert.Equal(t, totalRequesters, container.Len()) assert.Equal(t, totalRequesters, container.Len()) diff --git a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go index c0bacd54a14..2380a6380cd 100644 --- a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go +++ b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory.go @@ -2,6 +2,7 @@ package storagerequesterscontainer import ( "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/factory/containers" storagerequesters "github.com/multiversx/mx-chain-go/dataRetriever/storageRequesters" @@ -74,6 +75,11 @@ func (srcf *shardRequestersContainerFactory) Create() (dataRetriever.RequestersC return nil, err } + err = srcf.generateEquivalentProofsRequesters() + if err != nil { + return nil, err + } + return srcf.container, nil } @@ -88,7 +94,7 @@ func (srcf *shardRequestersContainerFactory)
generateHeaderRequesters() error { return err } - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(shardC.SelfId()) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(shardC.SelfId()) hdrNonceStore, err := srcf.store.GetStorer(hdrNonceHashDataUnit) if err != nil { return err @@ -166,6 +172,34 @@ func (srcf *shardRequestersContainerFactory) generateRewardRequester( return srcf.container.AddMultiple(keys, requesterSlice) } +func (srcf *shardRequestersContainerFactory) generateEquivalentProofsRequesters() error { + shardC := srcf.shardCoordinator + + keys := make([]string, 0) + requestersSlice := make([]dataRetriever.Requester, 0) + + // should be 2 requesters on shards, similar to interceptors: self_META + ALL + identifier := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(core.MetachainShardId) + requester, err := srcf.createEquivalentProofsRequester(identifier) + if err != nil { + return err + } + + requestersSlice = append(requestersSlice, requester) + keys = append(keys, identifier) + + identifier = common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + requester, err = srcf.createEquivalentProofsRequester(identifier) + if err != nil { + return err + } + + requestersSlice = append(requestersSlice, requester) + keys = append(keys, identifier) + + return srcf.container.AddMultiple(keys, requestersSlice) +} + // IsInterfaceNil returns true if there is no value under the interface func (srcf *shardRequestersContainerFactory) IsInterfaceNil() bool { return srcf == nil diff --git a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go index ed1e4a69bdf..ecf848b35ac 100644 --- a/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go +++ b/dataRetriever/factory/storageRequestersContainer/shardRequestersContainerFactory_test.go @@ -185,9 +185,10 @@ func TestShardRequestersContainerFactory_With4ShardsShouldWork(t *testing.T) { numRequesterMetaBlockHeaders := 1 numPeerAuthentication := 1 numValidatorInfo := 1 + numEquivalentProofs := 2 totalRequesters := numRequesterTxs + numRequesterHeaders + numRequesterMiniBlocks + numRequesterMetaBlockHeaders + numRequesterSCRs + numRequesterRewardTxs + - numPeerAuthentication + numValidatorInfo + numPeerAuthentication + numValidatorInfo + numEquivalentProofs assert.Equal(t, totalRequesters, container.Len()) } diff --git a/dataRetriever/interface.go b/dataRetriever/interface.go index 930b6aca124..f9d68fa9f92 100644 --- a/dataRetriever/interface.go +++ b/dataRetriever/interface.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/counting" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" @@ -21,7 +22,7 @@ type ResolverThrottler interface { // Resolver defines what a data resolver should do type Resolver interface { - ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error + ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) SetDebugHandler(handler DebugHandler) error
Close() error IsInterfaceNil() bool @@ -240,6 +241,7 @@ type PoolsHolder interface { PeerAuthentications() storage.Cacher Heartbeats() storage.Cacher ValidatorsInfo() ShardedDataCacherNotifier + Proofs() ProofsPool Close() error IsInterfaceNil() bool } @@ -357,3 +359,16 @@ type PeerAuthenticationPayloadValidator interface { ValidateTimestamp(payloadTimestamp int64) error IsInterfaceNil() bool } + +// ProofsPool defines the behaviour of a proofs pool components +type ProofsPool interface { + AddProof(headerProof data.HeaderProofHandler) bool + UpsertProof(headerProof data.HeaderProofHandler) bool + RegisterHandler(handler func(headerProof data.HeaderProofHandler)) + CleanupProofsBehindNonce(shardID uint32, nonce uint64) error + GetProof(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) + GetProofByNonce(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) + HasProof(shardID uint32, headerHash []byte) bool + IsProofInPoolEqualTo(headerProof data.HeaderProofHandler) bool + IsInterfaceNil() bool +} diff --git a/dataRetriever/mock/headerResolverStub.go b/dataRetriever/mock/headerResolverStub.go index fa87219b082..bb5fc6c75ab 100644 --- a/dataRetriever/mock/headerResolverStub.go +++ b/dataRetriever/mock/headerResolverStub.go @@ -2,16 +2,17 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/core" + "github.com/pkg/errors" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" - "github.com/pkg/errors" ) var errNotImplemented = errors.New("not implemented") // HeaderResolverStub - type HeaderResolverStub struct { - ProcessReceivedMessageCalled func(message p2p.MessageP2P) error + ProcessReceivedMessageCalled func(message p2p.MessageP2P) ([]byte, error) SetEpochHandlerCalled func(epochHandler dataRetriever.EpochHandler) error SetDebugHandlerCalled func(handler dataRetriever.DebugHandler) error CloseCalled func() error @@ -26,12 +27,12 @@ func (hrs *HeaderResolverStub) SetEpochHandler(epochHandler dataRetriever.EpochH } // ProcessReceivedMessage - -func (hrs *HeaderResolverStub) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) error { +func (hrs *HeaderResolverStub) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) ([]byte, error) { if hrs.ProcessReceivedMessageCalled != nil { return hrs.ProcessReceivedMessageCalled(message) } - return errNotImplemented + return nil, errNotImplemented } // SetDebugHandler - diff --git a/dataRetriever/mock/resolverStub.go b/dataRetriever/mock/resolverStub.go index c667c9459b2..a1d0b546f4b 100644 --- a/dataRetriever/mock/resolverStub.go +++ b/dataRetriever/mock/resolverStub.go @@ -2,19 +2,20 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" ) // ResolverStub - type ResolverStub struct { - ProcessReceivedMessageCalled func(message p2p.MessageP2P) error + ProcessReceivedMessageCalled func(message p2p.MessageP2P) ([]byte, error) SetDebugHandlerCalled func(handler dataRetriever.DebugHandler) error CloseCalled func() error } // ProcessReceivedMessage - -func (rs *ResolverStub) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) error { +func (rs *ResolverStub) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) 
([]byte, error) { return rs.ProcessReceivedMessageCalled(message) } diff --git a/dataRetriever/provider/miniBlocks_test.go b/dataRetriever/provider/miniBlocks_test.go index dc0e4f206e8..271d8ef55e6 100644 --- a/dataRetriever/provider/miniBlocks_test.go +++ b/dataRetriever/provider/miniBlocks_test.go @@ -8,14 +8,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" dataBlock "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/dataRetriever/provider" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockMiniblockProviderArgs( @@ -37,7 +38,7 @@ func createMockMiniblockProviderArgs( return nil, fmt.Errorf("not found") }, }, - MiniBlockPool: &testscommon.CacherStub{ + MiniBlockPool: &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if isByteSliceInSlice(key, dataPoolExistingHashes) { return &dataBlock.MiniBlock{}, true @@ -105,7 +106,7 @@ func TestNewMiniBlockProvider_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- GetMiniBlocksFromPool +// ------- GetMiniBlocksFromPool func TestMiniBlockProvider_GetMiniBlocksFromPoolFoundInPoolShouldReturn(t *testing.T) { t.Parallel() @@ -140,7 +141,7 @@ func TestMiniBlockProvider_GetMiniBlocksFromPoolWrongTypeInPoolShouldNotReturn(t hashes := [][]byte{[]byte("hash1"), []byte("hash2")} arg := createMockMiniblockProviderArgs(hashes, nil) - arg.MiniBlockPool = &testscommon.CacherStub{ + arg.MiniBlockPool = &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return "not a miniblock", true }, @@ -153,7 +154,7 @@ func TestMiniBlockProvider_GetMiniBlocksFromPoolWrongTypeInPoolShouldNotReturn(t assert.Equal(t, hashes, missingHashes) } -//------- GetMiniBlocks +// ------- GetMiniBlocks func TestMiniBlockProvider_GetMiniBlocksFoundInPoolShouldReturn(t *testing.T) { t.Parallel() diff --git a/dataRetriever/requestHandlers/interface.go b/dataRetriever/requestHandlers/interface.go index 13bac95037c..7db455a0f15 100644 --- a/dataRetriever/requestHandlers/interface.go +++ b/dataRetriever/requestHandlers/interface.go @@ -26,3 +26,8 @@ type HeaderRequester interface { NonceRequester EpochRequester } + +// EquivalentProofsRequester defines what an equivalent proofs requester can do +type EquivalentProofsRequester interface { + RequestDataFromNonce(nonceShardKey []byte, epoch uint32) error +} diff --git a/dataRetriever/requestHandlers/requestHandler.go b/dataRetriever/requestHandlers/requestHandler.go index 91e4992aee3..8b88fa98208 100644 --- a/dataRetriever/requestHandlers/requestHandler.go +++ b/dataRetriever/requestHandlers/requestHandler.go @@ -2,6 +2,7 @@ package requestHandlers import ( "encoding/binary" + "encoding/hex" "fmt" "runtime/debug" "sync" @@ -31,6 +32,7 @@ const uniqueHeadersSuffix = "hdr" const uniqueMetaHeadersSuffix = "mhdr" const uniqueTrieNodesSuffix = "tn" const 
uniqueValidatorInfoSuffix = "vi" +const uniqueEquivalentProofSuffix = "eqp" // TODO move the keys definitions that are whitelisted in core and use them in InterceptedData implementations, Identifiers() function @@ -299,9 +301,11 @@ func (rrh *resolverRequestHandler) RequestShardHeader(shardID uint32, hash []byt return } + epoch := rrh.getEpoch() log.Debug("requesting shard header from network", "shard", shardID, "hash", hash, + "epoch", epoch, ) headerRequester, err := rrh.getShardHeaderRequester(shardID) @@ -315,7 +319,6 @@ func (rrh *resolverRequestHandler) RequestShardHeader(shardID uint32, hash []byt rrh.whiteList.Add([][]byte{hash}) - epoch := rrh.getEpoch() err = headerRequester.RequestDataFromHash(hash, epoch) if err != nil { log.Debug("RequestShardHeader.RequestDataFromHash", @@ -867,3 +870,140 @@ func (rrh *resolverRequestHandler) RequestPeerAuthenticationsByHashes(destShardI ) } } + +// RequestEquivalentProofByHash asks for equivalent proof for the provided header hash +func (rrh *resolverRequestHandler) RequestEquivalentProofByHash(headerShard uint32, headerHash []byte) { + if !rrh.testIfRequestIsNeeded(headerHash, uniqueEquivalentProofSuffix) { + return + } + + epoch := rrh.getEpoch() + encodedHash := hex.EncodeToString(headerHash) + log.Debug("requesting equivalent proof from network", + "headerHash", encodedHash, + "shard", headerShard, + "epoch", epoch, + ) + + requester, err := rrh.getEquivalentProofsRequester(headerShard) + if err != nil { + log.Error("RequestEquivalentProofByHash.getEquivalentProofsRequester", + "error", err.Error(), + "headerHash", encodedHash, + "epoch", epoch, + ) + return + } + + rrh.whiteList.Add([][]byte{headerHash}) + + requestKey := fmt.Sprintf("%s-%d", encodedHash, headerShard) + err = requester.RequestDataFromHash([]byte(requestKey), epoch) + if err != nil { + log.Debug("RequestEquivalentProofByHash.RequestDataFromHash", + "error", err.Error(), + "headerHash", encodedHash, + "headerShard", headerShard, + "epoch", epoch, + ) + return + } + + rrh.addRequestedItems([][]byte{headerHash}, uniqueEquivalentProofSuffix) +} + +// RequestEquivalentProofByNonce asks for equivalent proof for the provided header nonce +func (rrh *resolverRequestHandler) RequestEquivalentProofByNonce(headerShard uint32, headerNonce uint64) { + key := common.GetEquivalentProofNonceShardKey(headerNonce, headerShard) + if !rrh.testIfRequestIsNeeded([]byte(key), uniqueEquivalentProofSuffix) { + return + } + + epoch := rrh.getEpoch() + log.Debug("requesting equivalent proof by nonce from network", + "headerNonce", headerNonce, + "headerShard", headerShard, + "epoch", epoch, + ) + + requester, err := rrh.getEquivalentProofsRequester(headerShard) + if err != nil { + log.Error("RequestEquivalentProofByNonce.getEquivalentProofsRequester", + "error", err.Error(), + "headerNonce", headerNonce, + ) + return + } + + proofsRequester, ok := requester.(EquivalentProofsRequester) + if !ok { + log.Warn("wrong assertion type when creating equivalent proofs requester") + return + } + + rrh.whiteList.Add([][]byte{[]byte(key)}) + + err = proofsRequester.RequestDataFromNonce([]byte(key), epoch) + if err != nil { + log.Debug("RequestEquivalentProofByNonce.RequestDataFromNonce", + "error", err.Error(), + "headerNonce", headerNonce, + "headerShard", headerShard, + "epoch", epoch, + ) + return + } + + rrh.addRequestedItems([][]byte{[]byte(key)}, uniqueEquivalentProofSuffix) +} + +func (rrh *resolverRequestHandler) getEquivalentProofsRequester(headerShard uint32) (dataRetriever.Requester, error) { + // 
there are multiple scenarios for equivalent proofs: + // 1. self meta requesting meta proof -> should request on equivalentProofs_ALL + // 2. self meta requesting shard proof -> should request on equivalentProofs_shard_META + // 3. self shard requesting intra proof -> should request on equivalentProofs_self_META + // 4. self shard requesting meta proof -> should request on equivalentProofs_ALL + // 5. self shard requesting cross proof -> should never happen! + + isSelfMeta := rrh.shardID == core.MetachainShardId + isRequestForMeta := headerShard == core.MetachainShardId + shardIdMismatch := rrh.shardID != headerShard && !isRequestForMeta && !isSelfMeta + isRequestInvalid := !isSelfMeta && shardIdMismatch + if isRequestInvalid { + return nil, dataRetriever.ErrBadRequest + } + + if isRequestForMeta { + topic := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + requester, err := rrh.requestersFinder.MetaChainRequester(topic) + if err != nil { + err = fmt.Errorf("%w, topic: %s, current shard ID: %d, requested header shard ID: %d", + err, topic, rrh.shardID, headerShard) + + log.Warn("available requesters in container", + "requesters", rrh.requestersFinder.RequesterKeys(), + ) + return nil, err + } + + return requester, nil + } + + crossShardID := core.MetachainShardId + if isSelfMeta { + crossShardID = headerShard + } + + requester, err := rrh.requestersFinder.CrossShardRequester(common.EquivalentProofsTopic, crossShardID) + if err != nil { + err = fmt.Errorf("%w, base topic: %s, current shard ID: %d, cross shard ID: %d", + err, common.EquivalentProofsTopic, rrh.shardID, crossShardID) + + log.Warn("available requesters in container", + "requesters", rrh.requestersFinder.RequesterKeys(), + ) + return nil, err + } + + return requester, nil +} diff --git a/dataRetriever/requestHandlers/requestHandler_test.go b/dataRetriever/requestHandlers/requestHandler_test.go index 48d27f46217..7717c10f7e9 100644 --- a/dataRetriever/requestHandlers/requestHandler_test.go +++ b/dataRetriever/requestHandlers/requestHandler_test.go @@ -2,6 +2,8 @@ package requestHandlers import ( "bytes" + "encoding/hex" + "fmt" "sync/atomic" "testing" "time" @@ -1997,3 +1999,244 @@ func TestResolverRequestHandler_IsInterfaceNil(t *testing.T) { ) require.False(t, rrh.IsInterfaceNil()) } + +func TestResolverRequestHandler_RequestEquivalentProofByHash(t *testing.T) { + t.Parallel() + + t.Run("hash already requested should work", func(t *testing.T) { + t.Parallel() + + providedHash := []byte("provided hash") + rrh, _ := NewResolverRequestHandler( + &dataRetrieverMocks.RequestersFinderStub{ + MetaChainRequesterCalled: func(baseTopic string) (requester dataRetriever.Requester, e error) { + require.Fail(t, "should not have been called") + return nil, nil + }, + }, + &mock.RequestedItemsHandlerStub{ + HasCalled: func(key string) bool { + return true + }, + }, + &mock.WhiteListHandlerStub{ + AddCalled: func(keys [][]byte) { + require.Fail(t, "should not have been called") + }, + }, + 100, + 0, + time.Second, + ) + + rrh.RequestEquivalentProofByHash(core.MetachainShardId, providedHash) + }) + t.Run("invalid cross-shard request should early exit", func(t *testing.T) { + t.Parallel() + + providedHash := []byte("provided hash") + rrh, _ := NewResolverRequestHandler( + &dataRetrieverMocks.RequestersFinderStub{}, + &mock.RequestedItemsHandlerStub{}, + &mock.WhiteListHandlerStub{ + AddCalled: func(keys [][]byte) { + require.Fail(t, "should not have been called") + }, + }, + 100,
+ 0, + time.Second, + ) + + rrh.RequestEquivalentProofByHash(1, providedHash) + }) + t.Run("missing metachain requester should early exit", func(t *testing.T) { + t.Parallel() + + providedHash := []byte("provided hash") + rrh, _ := NewResolverRequestHandler( + &dataRetrieverMocks.RequestersFinderStub{ + MetaChainRequesterCalled: func(baseTopic string) (dataRetriever.Requester, error) { + return nil, errExpected + }, + }, + &mock.RequestedItemsHandlerStub{}, + &mock.WhiteListHandlerStub{ + AddCalled: func(keys [][]byte) { + require.Fail(t, "should not have been called") + }, + }, + 100, + core.MetachainShardId, + time.Second, + ) + + rrh.RequestEquivalentProofByHash(core.MetachainShardId, providedHash) + }) + t.Run("missing cross-shard requester should early exit", func(t *testing.T) { + t.Parallel() + + providedHash := []byte("provided hash") + rrh, _ := NewResolverRequestHandler( + &dataRetrieverMocks.RequestersFinderStub{ + CrossShardRequesterCalled: func(baseTopic string, crossShard uint32) (dataRetriever.Requester, error) { + return nil, errExpected + }, + }, + &mock.RequestedItemsHandlerStub{}, + &mock.WhiteListHandlerStub{ + AddCalled: func(keys [][]byte) { + require.Fail(t, "should not have been called") + }, + }, + 100, + 0, + time.Second, + ) + + rrh.RequestEquivalentProofByHash(1, providedHash) + }) + t.Run("MetaChainRequester returns error", func(t *testing.T) { + t.Parallel() + + providedHash := []byte("provided hash") + res := &dataRetrieverMocks.RequesterStub{ + RequestDataFromHashCalled: func(hash []byte, epoch uint32) error { + require.Fail(t, "should not have been called") + + return nil + }, + } + + rrh, _ := NewResolverRequestHandler( + &dataRetrieverMocks.RequestersFinderStub{ + MetaChainRequesterCalled: func(baseTopic string) (requester dataRetriever.Requester, e error) { + return res, errExpected + }, + }, + &mock.RequestedItemsHandlerStub{}, + &mock.WhiteListHandlerStub{}, + 100, + core.MetachainShardId, + time.Second, + ) + + rrh.RequestEquivalentProofByHash(core.MetachainShardId, providedHash) + }) + t.Run("CrossChainRequester returns error", func(t *testing.T) { + t.Parallel() + + providedHash := []byte("provided hash") + res := &dataRetrieverMocks.RequesterStub{ + RequestDataFromHashCalled: func(hash []byte, epoch uint32) error { + require.Fail(t, "should not have been called") + return nil + }, + } + + rrh, _ := NewResolverRequestHandler( + &dataRetrieverMocks.RequestersFinderStub{ + CrossShardRequesterCalled: func(baseTopic string, crossShard uint32) (dataRetriever.Requester, error) { + return res, errExpected + }, + }, + &mock.RequestedItemsHandlerStub{}, + &mock.WhiteListHandlerStub{}, + 100, + 0, + time.Second, + ) + + rrh.RequestEquivalentProofByHash(0, providedHash) + }) + t.Run("RequestDataFromHash returns error", func(t *testing.T) { + t.Parallel() + + providedHash := []byte("provided hash") + res := &dataRetrieverMocks.RequesterStub{ + RequestDataFromHashCalled: func(hash []byte, epoch uint32) error { + return errExpected + }, + } + + rrh, _ := NewResolverRequestHandler( + &dataRetrieverMocks.RequestersFinderStub{ + MetaChainRequesterCalled: func(baseTopic string) (requester dataRetriever.Requester, e error) { + return res, nil + }, + }, + &mock.RequestedItemsHandlerStub{ + AddCalled: func(key string) error { + require.Fail(t, "should not have been called") + return nil + }, + }, + &mock.WhiteListHandlerStub{}, + 100, + core.MetachainShardId, + time.Second, + ) + + rrh.RequestEquivalentProofByHash(core.MetachainShardId, providedHash) + }) + 
t.Run("should work shard 0 requesting from 0", func(t *testing.T) { + t.Parallel() + + providedHash := []byte("provided hash") + wasCalled := false + res := &dataRetrieverMocks.RequesterStub{ + RequestDataFromHashCalled: func(hash []byte, epoch uint32) error { + key := fmt.Sprintf("%s-%d", hex.EncodeToString(providedHash), 0) + assert.True(t, bytes.Equal([]byte(key), hash)) + wasCalled = true + return nil + }, + } + + rrh, _ := NewResolverRequestHandler( + &dataRetrieverMocks.RequestersFinderStub{ + CrossShardRequesterCalled: func(baseTopic string, crossShard uint32) (dataRetriever.Requester, error) { + return res, nil + }, + }, + &mock.RequestedItemsHandlerStub{}, + &mock.WhiteListHandlerStub{}, + 100, + 0, + time.Second, + ) + + rrh.RequestEquivalentProofByHash(0, providedHash) + assert.True(t, wasCalled) + }) + t.Run("should work shard meta requesting from 0", func(t *testing.T) { + t.Parallel() + + providedHash := []byte("provided hash") + wasCalled := false + res := &dataRetrieverMocks.RequesterStub{ + RequestDataFromHashCalled: func(hash []byte, epoch uint32) error { + key := fmt.Sprintf("%s-%d", hex.EncodeToString(providedHash), 0) + assert.True(t, bytes.Equal([]byte(key), hash)) + wasCalled = true + return nil + }, + } + + rrh, _ := NewResolverRequestHandler( + &dataRetrieverMocks.RequestersFinderStub{ + CrossShardRequesterCalled: func(baseTopic string, crossShard uint32) (dataRetriever.Requester, error) { + return res, nil + }, + }, + &mock.RequestedItemsHandlerStub{}, + &mock.WhiteListHandlerStub{}, + 100, + core.MetachainShardId, + time.Second, + ) + + rrh.RequestEquivalentProofByHash(0, providedHash) + assert.True(t, wasCalled) + }) +} diff --git a/dataRetriever/requestHandlers/requesters/equivalentProofsRequester.go b/dataRetriever/requestHandlers/requesters/equivalentProofsRequester.go new file mode 100644 index 00000000000..f5f92d9a868 --- /dev/null +++ b/dataRetriever/requestHandlers/requesters/equivalentProofsRequester.go @@ -0,0 +1,87 @@ +package requesters + +import ( + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/dataRetriever" +) + +// ArgEquivalentProofsRequester is the argument structure used to create a new equivalent proofs requester instance +type ArgEquivalentProofsRequester struct { + ArgBaseRequester + EnableEpochsHandler common.EnableEpochsHandler +} + +type equivalentProofsRequester struct { + *baseRequester + enableEpochsHandler common.EnableEpochsHandler +} + +// NewEquivalentProofsRequester returns a new instance of equivalent proofs requester +func NewEquivalentProofsRequester(args ArgEquivalentProofsRequester) (*equivalentProofsRequester, error) { + err := checkArgBase(args.ArgBaseRequester) + if err != nil { + return nil, err + } + + if check.IfNil(args.EnableEpochsHandler) { + return nil, dataRetriever.ErrNilEnableEpochsHandler + } + + return &equivalentProofsRequester{ + baseRequester: createBaseRequester(args.ArgBaseRequester), + enableEpochsHandler: args.EnableEpochsHandler, + }, nil +} + +// RequestDataFromHash requests data from other peers by having a hash and the epoch as input +func (requester *equivalentProofsRequester) RequestDataFromHash(hash []byte, epoch uint32) error { + if !requester.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + return nil + } + + return requester.SendOnRequestTopic( + &dataRetriever.RequestData{ + Type: dataRetriever.HashType, + Value: hash, + Epoch: 
epoch, + }, + [][]byte{hash}, + ) +} + +// RequestDataFromHashArray requests equivalent proofs data from other peers by having multiple header hashes and the epoch as input +// all headers must be from the same epoch +func (requester *equivalentProofsRequester) RequestDataFromHashArray(hashes [][]byte, epoch uint32) error { + if !requester.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + return nil + } + + return requester.requestDataFromHashArray(hashes, epoch) +} + +// RequestDataFromNonce requests equivalent proofs data from other peers for the specified nonce-shard key +func (requester *equivalentProofsRequester) RequestDataFromNonce(nonceShardKey []byte, epoch uint32) error { + if !requester.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + return nil + } + + log.Trace("equivalentProofsRequester.RequestDataFromNonce", + "nonce-shard", string(nonceShardKey), + "epoch", epoch, + "topic", requester.RequestTopic()) + + return requester.SendOnRequestTopic( + &dataRetriever.RequestData{ + Type: dataRetriever.NonceType, + Value: nonceShardKey, + Epoch: epoch, + }, + [][]byte{nonceShardKey}, + ) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (requester *equivalentProofsRequester) IsInterfaceNil() bool { + return requester == nil +} diff --git a/dataRetriever/requestHandlers/requesters/requesters_test.go b/dataRetriever/requestHandlers/requesters/requesters_test.go index 4ec7ec9a74e..fe58ffd72be 100644 --- a/dataRetriever/requestHandlers/requesters/requesters_test.go +++ b/dataRetriever/requestHandlers/requesters/requesters_test.go @@ -2,6 +2,8 @@ package requesters import ( "errors" + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "testing" "github.com/multiversx/mx-chain-core-go/core/check" @@ -21,23 +23,27 @@ const ( txRequester requestHandlerType = "transactionRequester" trieRequester requestHandlerType = "trieNodeRequester" vInfoRequester requestHandlerType = "validatorInfoNodeRequester" + eqProofsRequester requestHandlerType = "equivalentProofsRequester" ) var expectedErr = errors.New("expected error") func Test_Requesters(t *testing.T) { t.Parallel() + testNewRequester(t, peerAuthRequester) testNewRequester(t, mbRequester) testNewRequester(t, txRequester) testNewRequester(t, trieRequester) testNewRequester(t, vInfoRequester) + testNewRequester(t, eqProofsRequester) testRequestDataFromHashArray(t, peerAuthRequester) testRequestDataFromHashArray(t, mbRequester) testRequestDataFromHashArray(t, txRequester) testRequestDataFromHashArray(t, trieRequester) testRequestDataFromHashArray(t, vInfoRequester) + testRequestDataFromHashArray(t, eqProofsRequester) testRequestDataFromReferenceAndChunk(t, trieRequester) } @@ -147,6 +153,15 @@ func getHandler(requesterType requestHandlerType, argsBase ArgBaseRequester) (ch return NewTrieNodeRequester(ArgTrieNodeRequester{argsBase}) case vInfoRequester: return NewValidatorInfoRequester(ArgValidatorInfoRequester{argsBase}) + case eqProofsRequester: + return NewEquivalentProofsRequester(ArgEquivalentProofsRequester{ + ArgBaseRequester: argsBase, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + }, + }) } return nil, errors.New("invalid requester type") } diff --git a/dataRetriever/resolvers/baseFullHistoryResolver.go 
b/dataRetriever/resolvers/baseFullHistoryResolver.go index 0372777f646..7c84ebcf205 100644 --- a/dataRetriever/resolvers/baseFullHistoryResolver.go +++ b/dataRetriever/resolvers/baseFullHistoryResolver.go @@ -9,7 +9,7 @@ type baseFullHistoryResolver struct { } func (bfhr *baseFullHistoryResolver) getFromStorage(key []byte, epoch uint32) ([]byte, error) { - //we just call the storer to search in the provided epoch. (it will search automatically also in the next epoch) + // we just call the storer to search in the provided epoch. (it will search automatically also in the next epoch) buff, err := bfhr.storer.GetFromEpoch(key, epoch) if err != nil { // default to a search first, maximize the chance of getting recent data diff --git a/dataRetriever/resolvers/disabled/resolver.go b/dataRetriever/resolvers/disabled/resolver.go index ac51a954260..c031ac2de48 100644 --- a/dataRetriever/resolvers/disabled/resolver.go +++ b/dataRetriever/resolvers/disabled/resolver.go @@ -2,6 +2,7 @@ package disabled import ( "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" ) @@ -15,8 +16,8 @@ func NewDisabledResolver() *resolver { } // ProcessReceivedMessage returns nil as it is disabled -func (r *resolver) ProcessReceivedMessage(_ p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) error { - return nil +func (r *resolver) ProcessReceivedMessage(_ p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) ([]byte, error) { + return []byte{}, nil } // SetDebugHandler returns nil as it is disabled diff --git a/dataRetriever/resolvers/equivalentProofsResolver.go b/dataRetriever/resolvers/equivalentProofsResolver.go new file mode 100644 index 00000000000..c36c3e9ac92 --- /dev/null +++ b/dataRetriever/resolvers/equivalentProofsResolver.go @@ -0,0 +1,257 @@ +package resolvers + +import ( + "fmt" + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data/batch" + "github.com/multiversx/mx-chain-core-go/data/typeConverters" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/process/interceptors/processor" + "github.com/multiversx/mx-chain-go/storage" + logger "github.com/multiversx/mx-chain-logger-go" +) + +// maxBuffToSendEquivalentProofs represents max buffer size to send in bytes +const maxBuffToSendEquivalentProofs = 1 << 18 // 256KB + +// ArgEquivalentProofsResolver is the argument structure used to create a new equivalent proofs resolver instance +type ArgEquivalentProofsResolver struct { + ArgBaseResolver + DataPacker dataRetriever.DataPacker + Storage dataRetriever.StorageService + EquivalentProofsPool processor.EquivalentProofsPool + NonceConverter typeConverters.Uint64ByteSliceConverter + IsFullHistoryNode bool +} + +type equivalentProofsResolver struct { + *baseResolver + baseStorageResolver + messageProcessor + dataPacker dataRetriever.DataPacker + storage dataRetriever.StorageService + equivalentProofsPool processor.EquivalentProofsPool + nonceConverter typeConverters.Uint64ByteSliceConverter +} + +// NewEquivalentProofsResolver creates an equivalent proofs resolver +func 
NewEquivalentProofsResolver(args ArgEquivalentProofsResolver) (*equivalentProofsResolver, error) { + err := checkArgEquivalentProofsResolver(args) + if err != nil { + return nil, err + } + + equivalentProofsStorage, err := args.Storage.GetStorer(dataRetriever.ProofsUnit) + if err != nil { + return nil, err + } + + return &equivalentProofsResolver{ + baseResolver: &baseResolver{ + TopicResolverSender: args.SenderResolver, + }, + baseStorageResolver: createBaseStorageResolver(equivalentProofsStorage, args.IsFullHistoryNode), + messageProcessor: messageProcessor{ + marshalizer: args.Marshaller, + antifloodHandler: args.AntifloodHandler, + throttler: args.Throttler, + topic: args.SenderResolver.RequestTopic(), + }, + dataPacker: args.DataPacker, + storage: args.Storage, + equivalentProofsPool: args.EquivalentProofsPool, + nonceConverter: args.NonceConverter, + }, nil +} + +func checkArgEquivalentProofsResolver(args ArgEquivalentProofsResolver) error { + err := checkArgBase(args.ArgBaseResolver) + if err != nil { + return err + } + if check.IfNil(args.DataPacker) { + return dataRetriever.ErrNilDataPacker + } + if check.IfNil(args.Storage) { + return dataRetriever.ErrNilStore + } + if check.IfNil(args.EquivalentProofsPool) { + return dataRetriever.ErrNilProofsPool + } + if check.IfNil(args.NonceConverter) { + return dataRetriever.ErrNilUint64ByteSliceConverter + } + + return nil +} + +// ProcessReceivedMessage represents the callback func from the p2p.Messenger that is called each time a new message is received +// (for the topic this validator was registered to, usually a request topic) +func (res *equivalentProofsResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { + err := res.canProcessMessage(message, fromConnectedPeer) + if err != nil { + return nil, err + } + + res.throttler.StartProcessing() + defer res.throttler.EndProcessing() + + rd, err := res.parseReceivedMessage(message, fromConnectedPeer) + if err != nil { + return nil, err + } + + switch rd.Type { + case dataRetriever.HashType: + return nil, res.resolveHashRequest(rd.Value, rd.Epoch, message.Peer(), source) + case dataRetriever.HashArrayType: + return nil, res.resolveMultipleHashesRequest(rd.Value, rd.Epoch, message.Peer(), source) + case dataRetriever.NonceType: + return nil, res.resolveNonceRequest(rd.Value, rd.Epoch, message.Peer(), source) + default: + err = dataRetriever.ErrRequestTypeNotImplemented + } + if err != nil { + return nil, fmt.Errorf("%w for value %s", err, logger.DisplayByteSlice(rd.Value)) + } + + return []byte{}, nil +} + +// resolveHashRequest sends the response for a hash request +func (res *equivalentProofsResolver) resolveHashRequest(hashShardKey []byte, epoch uint32, pid core.PeerID, source p2p.MessageHandler) error { + headerHash, shardID, err := common.GetHashAndShardFromKey(hashShardKey) + if err != nil { + return fmt.Errorf("resolveHashRequest.getHashAndShard error %w", err) + } + + data, err := res.fetchEquivalentProofAsByteSlice(headerHash, shardID, epoch) + if err != nil { + return fmt.Errorf("resolveHashRequest.fetchEquivalentProofAsByteSlice error %w", err) + } + + return res.Send(data, pid, source) +} + +// resolveMultipleHashesRequest sends the response for multiple hashes request +func (res *equivalentProofsResolver) resolveMultipleHashesRequest(hashShardKeysBuff []byte, epoch uint32, pid core.PeerID, source p2p.MessageHandler) error { + b := batch.Batch{} + err := res.marshalizer.Unmarshal(&b, hashShardKeysBuff) + 
if err != nil { + return err + } + hashShardKeys := b.Data + + equivalentProofsForHashes, err := res.fetchEquivalentProofsSlicesForHeaders(hashShardKeys, epoch) + if err != nil { + return fmt.Errorf("resolveMultipleHashesRequest.fetchEquivalentProofsSlicesForHeaders error %w", err) + } + + return res.sendEquivalentProofsForHashes(equivalentProofsForHashes, pid, source) +} + +// resolveNonceRequest sends the response for a nonce request +func (res *equivalentProofsResolver) resolveNonceRequest(nonceShardKey []byte, epoch uint32, pid core.PeerID, source p2p.MessageHandler) error { + data, err := res.fetchEquivalentProofFromNonceAsByteSlice(nonceShardKey, epoch) + if err != nil { + return fmt.Errorf("resolveNonceRequest.fetchEquivalentProofFromNonceAsByteSlice error %w", err) + } + + return res.Send(data, pid, source) +} + +// sendEquivalentProofsForHashes sends multiple equivalent proofs for specific hashes +func (res *equivalentProofsResolver) sendEquivalentProofsForHashes(dataBuff [][]byte, pid core.PeerID, source p2p.MessageHandler) error { + buffsToSend, err := res.dataPacker.PackDataInChunks(dataBuff, maxBuffToSendEquivalentProofs) + if err != nil { + return err + } + + for _, buff := range buffsToSend { + err = res.Send(buff, pid, source) + if err != nil { + return err + } + } + + return nil +} + +// fetchEquivalentProofsSlicesForHeaders fetches all equivalent proofs for the given header hashes +func (res *equivalentProofsResolver) fetchEquivalentProofsSlicesForHeaders(hashShardKeys [][]byte, epoch uint32) ([][]byte, error) { + equivalentProofs := make([][]byte, 0) + for _, hashShardKey := range hashShardKeys { + headerHash, shardID, err := common.GetHashAndShardFromKey(hashShardKey) + if err != nil { + return nil, err + } + + equivalentProofForHash, _ := res.fetchEquivalentProofAsByteSlice(headerHash, shardID, epoch) + if equivalentProofForHash != nil { + equivalentProofs = append(equivalentProofs, equivalentProofForHash) + } + } + + if len(equivalentProofs) == 0 { + return nil, dataRetriever.ErrEquivalentProofsNotFound + } + + return equivalentProofs, nil +} + +// fetchEquivalentProofAsByteSlice returns the value from equivalent proofs pool or storage if exists +func (res *equivalentProofsResolver) fetchEquivalentProofAsByteSlice(headerHash []byte, shardID uint32, epoch uint32) ([]byte, error) { + proof, err := res.equivalentProofsPool.GetProof(shardID, headerHash) + if err != nil { + return res.getFromStorage(headerHash, epoch) + } + + return res.marshalizer.Marshal(proof) +} + +// fetchEquivalentProofFromNonceAsByteSlice returns the value from equivalent proofs pool or storage if exists +func (res *equivalentProofsResolver) fetchEquivalentProofFromNonceAsByteSlice(nonceShardKey []byte, epoch uint32) ([]byte, error) { + headerNonce, shardID, err := common.GetNonceAndShardFromKey(nonceShardKey) + if err != nil { + return nil, fmt.Errorf("fetchEquivalentProofFromNonceAsByteSlice.getNonceAndShard error %w", err) + } + + proof, err := res.equivalentProofsPool.GetProofByNonce(headerNonce, shardID) + if err != nil { + return res.getProofFromStorageByNonce(headerNonce, shardID, epoch) + } + + return res.marshalizer.Marshal(proof) +} + +// getProofFromStorageByNonce returns the value from equivalent storage if exists +func (res *equivalentProofsResolver) getProofFromStorageByNonce(headerNonce uint64, shardID uint32, epoch uint32) ([]byte, error) { + storer, err := res.getStorerForShard(shardID) + if err != nil { + return nil, err + } + + nonceBytes := 
res.nonceConverter.ToByteSlice(headerNonce) + headerHash, err := storer.SearchFirst(nonceBytes) + if err != nil { + return nil, err + } + + return res.getFromStorage(headerHash, epoch) +} + +func (res *equivalentProofsResolver) getStorerForShard(shardID uint32) (storage.Storer, error) { + if shardID == core.MetachainShardId { + return res.storage.GetStorer(dataRetriever.MetaHdrNonceHashDataUnit) + } + + return res.storage.GetStorer(dataRetriever.GetHdrNonceHashDataUnit(shardID)) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (res *equivalentProofsResolver) IsInterfaceNil() bool { + return res == nil +} diff --git a/dataRetriever/resolvers/equivalentProofsResolver_test.go b/dataRetriever/resolvers/equivalentProofsResolver_test.go new file mode 100644 index 00000000000..6539617067a --- /dev/null +++ b/dataRetriever/resolvers/equivalentProofsResolver_test.go @@ -0,0 +1,818 @@ +package resolvers_test + +import ( + "encoding/hex" + "errors" + "fmt" + "math/big" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/batch" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/dataRetriever/mock" + "github.com/multiversx/mx-chain-go/dataRetriever/resolvers" + "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/storage" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" + storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" + "github.com/stretchr/testify/require" +) + +var ( + providedHashKey = []byte(fmt.Sprintf("%s-0", hex.EncodeToString([]byte("hash")))) + providedNonceKey = []byte("1-1") +) + +func createMockArgEquivalentProofsResolver() resolvers.ArgEquivalentProofsResolver { + return resolvers.ArgEquivalentProofsResolver{ + ArgBaseResolver: createMockArgBaseResolver(), + DataPacker: &mock.DataPackerStub{}, + Storage: &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &storageStubs.StorerStub{}, nil + }, + }, + EquivalentProofsPool: &dataRetrieverMocks.ProofsPoolMock{}, + NonceConverter: &mock.Uint64ByteSliceConverterMock{ + ToByteSliceCalled: func(u uint64) []byte { + return big.NewInt(0).SetUint64(u).Bytes() + }, + }, + IsFullHistoryNode: false, + } +} + +func createMockRequestedProofsBuff() ([]byte, error) { + marshaller := &marshal.GogoProtoMarshalizer{} + + return marshaller.Marshal(&batch.Batch{Data: [][]byte{[]byte("proof")}}) +} + +func TestNewEquivalentProofsResolver(t *testing.T) { + t.Parallel() + + t.Run("nil SenderResolver should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.SenderResolver = nil + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Equal(t, dataRetriever.ErrNilResolverSender, err) + require.Nil(t, res) + }) + t.Run("nil DataPacker should error", func(t *testing.T) { + t.Parallel() + + args := 
createMockArgEquivalentProofsResolver() + args.DataPacker = nil + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Equal(t, dataRetriever.ErrNilDataPacker, err) + require.Nil(t, res) + }) + t.Run("nil Storage should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.Storage = nil + res, err := resolvers.NewEquivalentProofsResolver(args) + require.True(t, errors.Is(err, dataRetriever.ErrNilStore)) + require.Nil(t, res) + }) + t.Run("nil NonceConverter should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.NonceConverter = nil + res, err := resolvers.NewEquivalentProofsResolver(args) + require.True(t, errors.Is(err, dataRetriever.ErrNilUint64ByteSliceConverter)) + require.Nil(t, res) + }) + t.Run("nil EquivalentProofsPool should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = nil + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Equal(t, dataRetriever.ErrNilProofsPool, err) + require.Nil(t, res) + }) + t.Run("error on GetStorer should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.Storage = &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return nil, expectedErr + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Equal(t, expectedErr, err) + require.Nil(t, res) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + res, err := resolvers.NewEquivalentProofsResolver(createMockArgEquivalentProofsResolver()) + require.NoError(t, err) + require.NotNil(t, res) + }) +} + +func TestEquivalentProofsResolver_IsInterfaceNil(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = nil + res, _ := resolvers.NewEquivalentProofsResolver(args) + require.True(t, res.IsInterfaceNil()) + + res, _ = resolvers.NewEquivalentProofsResolver(createMockArgEquivalentProofsResolver()) + require.False(t, res.IsInterfaceNil()) +} + +func TestEquivalentProofsResolver_ProcessReceivedMessage(t *testing.T) { + t.Parallel() + + t.Run("nil message should error", func(t *testing.T) { + t.Parallel() + + res, _ := resolvers.NewEquivalentProofsResolver(createMockArgEquivalentProofsResolver()) + + _, err := res.ProcessReceivedMessage(nil, "", nil) + require.Equal(t, dataRetriever.ErrNilMessage, err) + }) + t.Run("parseReceivedMessage returns error due to marshaller error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.Marshaller = &mock.MarshalizerStub{ + UnmarshalCalled: func(obj interface{}, buff []byte) error { + return expectedErr + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.True(t, errors.Is(err, expectedErr)) + require.Nil(t, msgID) + }) + t.Run("invalid request type should error", func(t *testing.T) { + t.Parallel() + + requestedBuff, err := createMockRequestedProofsBuff() + require.Nil(t, err) + + args := createMockArgEquivalentProofsResolver() + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.ChunkType, requestedBuff), fromConnectedPeer, 
&p2pmocks.MessengerStub{}) + require.True(t, errors.Is(err, dataRetriever.ErrRequestTypeNotImplemented)) + require.Nil(t, msgID) + }) + t.Run("resolveHashRequest: marshal failure before send should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + require.Equal(t, []byte("hash"), headerHash) + + return &block.HeaderProof{}, nil + }, + } + mockMarshaller := &marshallerMock.MarshalizerMock{} + args.Marshaller = &marshallerMock.MarshalizerStub{ + MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, expectedErr + }, + UnmarshalCalled: func(obj interface{}, buff []byte) error { + return mockMarshaller.Unmarshal(obj, buff) + }, + } + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, providedHashKey), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.True(t, errors.Is(err, expectedErr)) + require.Nil(t, msgID) + }) + t.Run("resolveHashRequest: invalid key should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + // invalid format + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("invalidKey")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.Error(t, err) + require.Nil(t, msgID) + + // invalid shard + msgID, err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash_notAShard")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.Error(t, err) + require.Nil(t, msgID) + }) + t.Run("resolveHashRequest: hash not found anywhere should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + wasGetProofByHashCalled := false + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + wasGetProofByHashCalled = true + require.Equal(t, []byte("hash"), headerHash) + + return nil, expectedErr + }, + } + wasSearchFirstCalled := false + args.Storage = &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &storageStubs.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + wasSearchFirstCalled = true + require.Equal(t, []byte("hash"), key) + + return nil, expectedErr + }, + }, nil + }, + } + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, providedHashKey), fromConnectedPeer, &p2pmocks.MessengerStub{}) + 
require.True(t, errors.Is(err, expectedErr)) + require.Nil(t, msgID) + require.True(t, wasGetProofByHashCalled) + require.True(t, wasSearchFirstCalled) + }) + t.Run("resolveHashRequest: should work and return from pool", func(t *testing.T) { + t.Parallel() + + providedProof := &block.HeaderProof{} + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + require.Equal(t, []byte("hash"), headerHash) + + return providedProof, nil + }, + } + wasSendCalled := false + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + wasSendCalled = true + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, providedHashKey), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.NoError(t, err) + require.Nil(t, msgID) + require.True(t, wasSendCalled) + }) + t.Run("resolveHashRequest: should work and return from storage", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + require.Equal(t, []byte("hash"), headerHash) + + return nil, expectedErr + }, + } + args.Storage = &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &storageStubs.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + require.Equal(t, []byte("hash"), key) + + return []byte("proof"), nil + }, + }, nil + }, + } + wasSendCalled := false + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + wasSendCalled = true + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, providedHashKey), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.NoError(t, err) + require.Nil(t, msgID) + require.True(t, wasSendCalled) + }) + t.Run("resolveMultipleHashesRequest: hashes unmarshall error should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, []byte("invalid data")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.Error(t, err) + require.Nil(t, msgID) + }) + t.Run("resolveMultipleHashesRequest: invalid key should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + providedHashKeyes, err := 
args.Marshaller.Marshal(batch.Batch{Data: [][]byte{[]byte("invalidKey")}}) + require.Nil(t, err) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashKeyes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.Error(t, err) + require.Nil(t, msgID) + }) + t.Run("resolveMultipleHashesRequest: hash not found anywhere should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + wasGetProofByHashCalled := false + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + wasGetProofByHashCalled = true + require.Equal(t, []byte("hash"), headerHash) + + return nil, expectedErr + }, + } + wasSearchFirstCalled := false + args.Storage = &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &storageStubs.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + wasSearchFirstCalled = true + require.Equal(t, []byte("hash"), key) + + return nil, expectedErr + }, + }, nil + }, + } + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + providedHashKeyes, err := args.Marshaller.Marshal(batch.Batch{Data: [][]byte{providedHashKey}}) + require.Nil(t, err) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashKeyes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.True(t, errors.Is(err, dataRetriever.ErrEquivalentProofsNotFound)) + require.Nil(t, msgID) + require.True(t, wasGetProofByHashCalled) + require.True(t, wasSearchFirstCalled) + }) + t.Run("resolveMultipleHashesRequest: PackDataInChunks error should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + require.Equal(t, []byte("hash"), headerHash) + + return &block.HeaderProof{}, nil + }, + } + args.DataPacker = &mock.DataPackerStub{ + PackDataInChunksCalled: func(data [][]byte, limit int) ([][]byte, error) { + return nil, expectedErr + }, + } + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + providedHashKeyes, err := args.Marshaller.Marshal(batch.Batch{Data: [][]byte{providedHashKey}}) + require.Nil(t, err) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashKeyes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.Equal(t, expectedErr, err) + require.Nil(t, msgID) + }) + t.Run("resolveMultipleHashesRequest: Send error should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + require.Equal(t, []byte("hash"), headerHash) + + return &block.HeaderProof{}, nil + }, + } + args.SenderResolver = 
&mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + return expectedErr + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + providedHashKeyes, err := args.Marshaller.Marshal(batch.Batch{Data: [][]byte{providedHashKey}}) + require.Nil(t, err) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashKeyes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.Equal(t, expectedErr, err) + require.Nil(t, msgID) + }) + t.Run("resolveMultipleHashesRequest: one hash should work and return from pool", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + require.Equal(t, []byte("hash"), headerHash) + + return &block.HeaderProof{}, nil + }, + } + wasSendCalled := false + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + wasSendCalled = true + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + providedHashKeyes, err := args.Marshaller.Marshal(batch.Batch{Data: [][]byte{providedHashKey}}) + require.Nil(t, err) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashKeyes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.NoError(t, err) + require.Nil(t, msgID) + require.True(t, wasSendCalled) + }) + t.Run("resolveMultipleHashesRequest: one hash should work and return from storage", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + require.Equal(t, []byte("hash"), headerHash) + + return nil, expectedErr + }, + } + args.Storage = &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &storageStubs.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + require.Equal(t, []byte("hash"), key) + + return []byte("proof"), nil + }, + }, nil + }, + } + wasSendCalled := false + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + wasSendCalled = true + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + providedHashKeyes, err := args.Marshaller.Marshal(batch.Batch{Data: [][]byte{providedHashKey}}) + require.Nil(t, err) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashKeyes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.NoError(t, err) + require.Nil(t, msgID) + require.True(t, wasSendCalled) + }) + t.Run("resolveMultipleHashesRequest: one hash in pool, one in storage should work", func(t *testing.T) { + t.Parallel() + + providedHashKey2 := []byte(fmt.Sprintf("%s-2", hex.EncodeToString([]byte("hash2")))) + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + if string(headerHash) == "hash" { + return &block.HeaderProof{}, nil + } + return nil, 
expectedErr + }, + } + args.Storage = &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &storageStubs.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + if string(key) == "hash2" { + return []byte("proof"), nil + } + return nil, expectedErr + }, + }, nil + }, + } + cntSendCalled := 0 + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + cntSendCalled++ + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + providedHashKeyes, err := args.Marshaller.Marshal(batch.Batch{Data: [][]byte{providedHashKey, providedHashKey2}}) + require.Nil(t, err) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashKeyes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.NoError(t, err) + require.Nil(t, msgID) + require.Equal(t, 2, cntSendCalled) + }) + t.Run("resolveNonceRequest: marshal failure of proof should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofByNonceCalled: func(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + require.Equal(t, uint64(1), headerNonce) + require.Equal(t, uint32(1), shardID) + + return &block.HeaderProof{}, nil + }, + } + mockMarshaller := &marshallerMock.MarshalizerMock{} + args.Marshaller = &marshallerMock.MarshalizerStub{ + MarshalCalled: func(obj interface{}) ([]byte, error) { + return nil, expectedErr + }, + UnmarshalCalled: func(obj interface{}, buff []byte) error { + return mockMarshaller.Unmarshal(obj, buff) + }, + } + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, providedNonceKey), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.True(t, errors.Is(err, expectedErr)) + require.Nil(t, msgID) + }) + t.Run("resolveNonceRequest: invalid key should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + // invalid format + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, []byte("invalidkey")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.Error(t, err) + require.Nil(t, msgID) + + // invalid nonce + msgID, err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, []byte("notANonce_0")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.Error(t, err) + require.Nil(t, msgID) + + // invalid shard + msgID, err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, []byte("0_notAShard")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.Error(t, err) + require.Nil(t, msgID) + }) + t.Run("resolveNonceRequest: error on nonceHashStorage should error", func(t *testing.T) { + t.Parallel() + + 
providedMetaNonceKey := fmt.Sprintf("%d-%d", 1, core.MetachainShardId) // meta for coverage + args := createMockArgEquivalentProofsResolver() + wasGetProofByNonceCalled := false + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofByNonceCalled: func(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + wasGetProofByNonceCalled = true + require.Equal(t, uint64(1), headerNonce) + require.Equal(t, core.MetachainShardId, shardID) + + return nil, expectedErr + }, + } + args.Storage = &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &storageStubs.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + return nil, expectedErr + }, + }, nil + }, + } + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, []byte(providedMetaNonceKey)), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.True(t, errors.Is(err, expectedErr)) + require.Nil(t, msgID) + require.True(t, wasGetProofByNonceCalled) + }) + t.Run("resolveNonceRequest: nonce not found anywhere should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + wasGetProofByNonceCalled := false + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofByNonceCalled: func(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + wasGetProofByNonceCalled = true + require.Equal(t, uint64(1), headerNonce) + require.Equal(t, uint32(1), shardID) + + return nil, expectedErr + }, + } + wasSearchFirstCalled := false + args.Storage = &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &storageStubs.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + wasSearchFirstCalled = true + + return nil, expectedErr + }, + }, nil + }, + } + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + require.Fail(t, "should have not been called") + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, providedNonceKey), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.True(t, errors.Is(err, expectedErr)) + require.Nil(t, msgID) + require.True(t, wasGetProofByNonceCalled) + require.True(t, wasSearchFirstCalled) + }) + t.Run("resolveNonceRequest: should work and return from pool", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofByNonceCalled: func(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + require.Equal(t, uint64(1), headerNonce) + require.Equal(t, uint32(1), shardID) + + return &block.HeaderProof{}, nil + }, + } + wasSendCalled := false + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + wasSendCalled = true + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + 
+ msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, providedNonceKey), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.NoError(t, err) + require.Nil(t, msgID) + require.True(t, wasSendCalled) + }) + t.Run("resolveNonceRequest: should work and return from storage", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsResolver() + args.EquivalentProofsPool = &dataRetrieverMocks.ProofsPoolMock{ + GetProofByNonceCalled: func(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + require.Equal(t, uint64(1), headerNonce) + require.Equal(t, uint32(1), shardID) + + return nil, expectedErr + }, + } + args.Storage = &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + expectedUnitType := dataRetriever.GetHdrNonceHashDataUnit(1) + if unitType == expectedUnitType { + return &storageStubs.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + return []byte("hash"), nil + }, + }, nil + } + + return &storageStubs.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + require.Equal(t, []byte("hash"), key) + + return []byte("proof"), nil + }, + }, nil + }, + } + wasSendCalled := false + args.SenderResolver = &mock.TopicResolverSenderStub{ + SendCalled: func(buff []byte, peer core.PeerID, source p2p.MessageHandler) error { + wasSendCalled = true + + return nil + }, + } + res, err := resolvers.NewEquivalentProofsResolver(args) + require.Nil(t, err) + + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, providedNonceKey), fromConnectedPeer, &p2pmocks.MessengerStub{}) + require.NoError(t, err) + require.Nil(t, msgID) + require.True(t, wasSendCalled) + }) +} diff --git a/dataRetriever/resolvers/headerResolver.go b/dataRetriever/resolvers/headerResolver.go index 877c57a31da..dbd8626bf3a 100644 --- a/dataRetriever/resolvers/headerResolver.go +++ b/dataRetriever/resolvers/headerResolver.go @@ -6,12 +6,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/typeConverters" + "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/resolvers/epochproviders/disabled" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/storage" - "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("dataRetriever/resolvers") @@ -109,10 +110,10 @@ func (hdrRes *HeaderResolver) SetEpochHandler(epochHandler dataRetriever.EpochHa // ProcessReceivedMessage will be the callback func from the p2p.Messenger and will be called each time a new message was received // (for the topic this validator was registered to, usually a request topic) -func (hdrRes *HeaderResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error { +func (hdrRes *HeaderResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { err := hdrRes.canProcessMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } hdrRes.throttler.StartProcessing() @@ -120,7 +121,7 @@ func 
(hdrRes *HeaderResolver) ProcessReceivedMessage(message p2p.MessageP2P, fro rd, err := hdrRes.parseReceivedMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } var buff []byte @@ -133,7 +134,7 @@ func (hdrRes *HeaderResolver) ProcessReceivedMessage(message p2p.MessageP2P, fro case dataRetriever.EpochType: buff, err = hdrRes.resolveHeaderFromEpoch(rd.Value) default: - return dataRetriever.ErrResolveTypeUnknown + return nil, dataRetriever.ErrResolveTypeUnknown } if err != nil { hdrRes.DebugHandler().LogFailedToResolveData( @@ -141,7 +142,7 @@ func (hdrRes *HeaderResolver) ProcessReceivedMessage(message p2p.MessageP2P, fro rd.Value, err, ) - return err + return nil, err } if buff == nil { @@ -153,12 +154,15 @@ func (hdrRes *HeaderResolver) ProcessReceivedMessage(message p2p.MessageP2P, fro log.Trace("missing data", "data", rd) - return nil + return []byte{}, nil } hdrRes.DebugHandler().LogSucceededToResolveData(hdrRes.topic, rd.Value) - - return hdrRes.Send(buff, message.Peer(), source) + err = hdrRes.Send(buff, message.Peer(), source) + if err != nil { + return nil, err + } + return []byte{}, nil } func (hdrRes *HeaderResolver) resolveHeaderFromNonce(rd *dataRetriever.RequestData) ([]byte, error) { diff --git a/dataRetriever/resolvers/headerResolver_test.go b/dataRetriever/resolvers/headerResolver_test.go index f50606a244e..61aebfa2a02 100644 --- a/dataRetriever/resolvers/headerResolver_test.go +++ b/dataRetriever/resolvers/headerResolver_test.go @@ -12,13 +12,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/dataRetriever/resolvers" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" ) func createMockArgBaseResolver() resolvers.ArgBaseResolver { @@ -165,10 +166,11 @@ func TestHeaderResolver_ProcessReceivedCanProcessMessageErrorsShouldErr(t *testi } hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, nil), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, nil), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestHeaderResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { @@ -177,10 +179,11 @@ func TestHeaderResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { arg := createMockArgHeaderResolver() hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, nil), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, nil), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, 
dataRetriever.ErrNilValue, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestHeaderResolver_ProcessReceivedMessage_WrongIdentifierStartBlock(t *testing.T) { @@ -190,10 +193,11 @@ func TestHeaderResolver_ProcessReceivedMessage_WrongIdentifierStartBlock(t *test hdrRes, _ := resolvers.NewHeaderResolver(arg) requestedData := []byte("request") - err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, requestedData), "", &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, requestedData), "", &p2pmocks.MessengerStub{}) assert.Equal(t, core.ErrInvalidIdentifierForEpochStartBlockRequest, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestHeaderResolver_ProcessReceivedMessageEpochTypeUnknownEpochShouldWork(t *testing.T) { @@ -215,11 +219,12 @@ func TestHeaderResolver_ProcessReceivedMessageEpochTypeUnknownEpochShouldWork(t hdrRes, _ := resolvers.NewHeaderResolver(arg) requestedData := []byte(fmt.Sprintf("epoch_%d", math.MaxUint32)) - err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, requestedData), "", &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, requestedData), "", &p2pmocks.MessengerStub{}) assert.NoError(t, err) assert.True(t, wasSent) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestHeaderResolver_ProcessReceivedMessage_Ok(t *testing.T) { @@ -234,10 +239,11 @@ func TestHeaderResolver_ProcessReceivedMessage_Ok(t *testing.T) { hdrRes, _ := resolvers.NewHeaderResolver(arg) requestedData := []byte("request_1") - err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, requestedData), "", &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, requestedData), "", &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestHeaderResolver_ProcessReceivedMessageRequestUnknownTypeShouldErr(t *testing.T) { @@ -246,17 +252,17 @@ func TestHeaderResolver_ProcessReceivedMessageRequestUnknownTypeShouldErr(t *tes arg := createMockArgHeaderResolver() hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage(createRequestMsg(254, make([]byte, 0)), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(254, make([]byte, 0)), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, dataRetriever.ErrResolveTypeUnknown, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestHeaderResolver_ValidateRequestHashTypeFoundInHdrPoolShouldSearchAndSend(t *testing.T) { t.Parallel() requestedData := []byte("aaaa") - searchWasCalled := false sendWasCalled := false @@ -280,12 +286,13 @@ func TestHeaderResolver_ValidateRequestHashTypeFoundInHdrPoolShouldSearchAndSend arg.Headers = headers hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := 
hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedData), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedData), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, searchWasCalled) assert.True(t, sendWasCalled) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestHeaderResolver_ValidateRequestHashTypeFoundInHdrPoolShouldSearchAndSendFullHistory(t *testing.T) { @@ -317,12 +324,13 @@ func TestHeaderResolver_ValidateRequestHashTypeFoundInHdrPoolShouldSearchAndSend arg.Headers = headers hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedData), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedData), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, searchWasCalled) assert.True(t, sendWasCalled) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestHeaderResolver_ProcessReceivedMessageRequestHashTypeFoundInHdrPoolMarshalizerFailsShouldErr(t *testing.T) { @@ -360,10 +368,11 @@ func TestHeaderResolver_ProcessReceivedMessageRequestHashTypeFoundInHdrPoolMarsh arg.Headers = headers hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedData), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedData), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, errExpected, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestHeaderResolver_ProcessReceivedMessageRequestRetFromStorageShouldRetValAndSend(t *testing.T) { @@ -400,12 +409,13 @@ func TestHeaderResolver_ProcessReceivedMessageRequestRetFromStorageShouldRetValA arg.HdrStorage = store hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedData), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedData), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, wasGotFromStorage) assert.True(t, wasSent) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeInvalidSliceShouldErr(t *testing.T) { @@ -419,10 +429,11 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeInvalidSliceShould } hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, []byte("aaa")), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, []byte("aaa")), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, dataRetriever.ErrInvalidNonceByteSlice, err) 
assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceShouldCallWithTheCorrectEpoch(t *testing.T) { @@ -446,7 +457,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceShouldCallWithTheCorre }, ) msg := &p2pmocks.P2PMessageMock{DataField: buff} - _ = hdrRes.ProcessReceivedMessage(msg, "", &p2pmocks.MessengerStub{}) + _, _ = hdrRes.ProcessReceivedMessage(msg, "", &p2pmocks.MessengerStub{}) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) } @@ -488,7 +499,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeNotFoundInHdrNonce } hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage( + msgID, err := hdrRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.NonceType, arg.NonceConverter.ToByteSlice(requestedNonce)), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -497,6 +508,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeNotFoundInHdrNonce assert.False(t, wasSent) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoolShouldRetFromPoolAndSend(t *testing.T) { @@ -534,7 +546,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo } hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage( + msgID, err := hdrRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.NonceType, arg.NonceConverter.ToByteSlice(requestedNonce)), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -545,6 +557,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo assert.True(t, wasSent) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoolShouldRetFromStorageAndSend(t *testing.T) { @@ -597,7 +610,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo } hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage( + msgID, err := hdrRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.NonceType, arg.NonceConverter.ToByteSlice(requestedNonce)), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -608,6 +621,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo assert.True(t, wasSend) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoolButMarshalFailsShouldError(t *testing.T) { @@ -654,7 +668,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo } hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage( + msgID, err := hdrRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.NonceType, arg.NonceConverter.ToByteSlice(requestedNonce)), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -664,6 +678,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo 
assert.True(t, wasResolved) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeNotFoundInHdrNoncePoolShouldRetFromPoolAndSend(t *testing.T) { @@ -696,7 +711,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeNotFoundInHdrNonce } hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage( + msgID, err := hdrRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.NonceType, arg.NonceConverter.ToByteSlice(requestedNonce)), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -706,6 +721,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeNotFoundInHdrNonce assert.True(t, wasSend) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoolCheckRetErr(t *testing.T) { @@ -753,7 +769,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo } hdrRes, _ := resolvers.NewHeaderResolver(arg) - err := hdrRes.ProcessReceivedMessage( + msgID, err := hdrRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.NonceType, arg.NonceConverter.ToByteSlice(requestedNonce)), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -762,6 +778,7 @@ func TestHeaderResolver_ProcessReceivedMessageRequestNonceTypeFoundInHdrNoncePoo assert.Equal(t, errExpected, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestHeaderResolver_SetEpochHandlerNilShouldErr(t *testing.T) { @@ -816,8 +833,9 @@ func TestHeaderResolver_SetEpochHandlerConcurrency(t *testing.T) { assert.Nil(t, err) return } - err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, []byte("request_1")), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := hdrRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.EpochType, []byte("request_1")), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Nil(t, err) + assert.Len(t, msgID, 0) }(i) } wg.Wait() diff --git a/dataRetriever/resolvers/miniblockResolver.go b/dataRetriever/resolvers/miniblockResolver.go index 0c1a1460074..3fb74105af5 100644 --- a/dataRetriever/resolvers/miniblockResolver.go +++ b/dataRetriever/resolvers/miniblockResolver.go @@ -6,10 +6,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/storage" - logger "github.com/multiversx/mx-chain-logger-go" ) var _ dataRetriever.Resolver = (*miniblockResolver)(nil) @@ -77,10 +78,10 @@ func checkArgMiniblockResolver(arg ArgMiniblockResolver) error { // ProcessReceivedMessage will be the callback func from the p2p.Messenger and will be called each time a new message was received // (for the topic this validator was registered to, usually a request topic) -func (mbRes *miniblockResolver) ProcessReceivedMessage(message p2p.MessageP2P, 
fromConnectedPeer core.PeerID, source p2p.MessageHandler) error { +func (mbRes *miniblockResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { err := mbRes.canProcessMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } mbRes.throttler.StartProcessing() @@ -88,7 +89,7 @@ func (mbRes *miniblockResolver) ProcessReceivedMessage(message p2p.MessageP2P, f rd, err := mbRes.parseReceivedMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } switch rd.Type { @@ -101,10 +102,10 @@ func (mbRes *miniblockResolver) ProcessReceivedMessage(message p2p.MessageP2P, f } if err != nil { - err = fmt.Errorf("%w for hash %s", err, logger.DisplayByteSlice(rd.Value)) + return nil, fmt.Errorf("%w for hash %s", err, logger.DisplayByteSlice(rd.Value)) } - return err + return []byte{}, nil } func (mbRes *miniblockResolver) resolveMbRequestByHash(hash []byte, pid core.PeerID, epoch uint32, source p2p.MessageHandler) error { diff --git a/dataRetriever/resolvers/miniblockResolver_test.go b/dataRetriever/resolvers/miniblockResolver_test.go index 35588e9d6a9..c86fc0c3989 100644 --- a/dataRetriever/resolvers/miniblockResolver_test.go +++ b/dataRetriever/resolvers/miniblockResolver_test.go @@ -9,14 +9,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/dataRetriever/resolvers" "github.com/multiversx/mx-chain-go/p2p" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" ) var fromConnectedPeerId = core.PeerID("from connected peer Id") @@ -24,7 +25,7 @@ var fromConnectedPeerId = core.PeerID("from connected peer Id") func createMockArgMiniblockResolver() resolvers.ArgMiniblockResolver { return resolvers.ArgMiniblockResolver{ ArgBaseResolver: createMockArgBaseResolver(), - MiniBlockPool: testscommon.NewCacherStub(), + MiniBlockPool: cache.NewCacherStub(), MiniBlockStorage: &storageStubs.StorerStub{}, DataPacker: &mock.DataPackerStub{}, } @@ -128,10 +129,11 @@ func TestMiniblockResolver_ProcessReceivedAntifloodErrorsShouldErr(t *testing.T) } mbRes, _ := resolvers.NewMiniblockResolver(arg) - err := mbRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mbRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestMiniblockResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { @@ -140,10 +142,11 @@ func TestMiniblockResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) arg := 
createMockArgMiniblockResolver() mbRes, _ := resolvers.NewMiniblockResolver(arg) - err := mbRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mbRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, dataRetriever.ErrNilValue, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestMiniblockResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T) { @@ -152,11 +155,12 @@ func TestMiniblockResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T arg := createMockArgMiniblockResolver() mbRes, _ := resolvers.NewMiniblockResolver(arg) - err := mbRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, make([]byte, 0)), fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mbRes.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, make([]byte, 0)), fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, dataRetriever.ErrRequestTypeNotImplemented)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolShouldRetValAndSend(t *testing.T) { @@ -173,7 +177,7 @@ func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolShouldRetValAndSend( wasResolved := false wasSent := false - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { if bytes.Equal(key, mbHash) { wasResolved = true @@ -198,7 +202,7 @@ func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolShouldRetValAndSend( } mbRes, _ := resolvers.NewMiniblockResolver(arg) - err := mbRes.ProcessReceivedMessage( + msgID, err := mbRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.HashArrayType, requestedBuff), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -209,6 +213,7 @@ func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolShouldRetValAndSend( assert.True(t, wasSent) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolMarshalizerFailShouldErr(t *testing.T) { @@ -232,7 +237,7 @@ func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolMarshalizerFailShoul assert.Nil(t, merr) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { if bytes.Equal(key, mbHash) { return &block.MiniBlock{}, true @@ -253,7 +258,7 @@ func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolMarshalizerFailShoul arg.Marshaller = marshalizer mbRes, _ := resolvers.NewMiniblockResolver(arg) - err := mbRes.ProcessReceivedMessage( + msgID, err := mbRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.HashArrayType, requestedBuff), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -262,6 +267,7 @@ func TestMiniblockResolver_ProcessReceivedMessageFoundInPoolMarshalizerFailShoul assert.True(t, errors.Is(err, errExpected)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func 
TestMiniblockResolver_ProcessReceivedMessageUnmarshalFails(t *testing.T) { @@ -286,7 +292,7 @@ func TestMiniblockResolver_ProcessReceivedMessageUnmarshalFails(t *testing.T) { assert.Nil(t, merr) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -309,7 +315,7 @@ func TestMiniblockResolver_ProcessReceivedMessageUnmarshalFails(t *testing.T) { } mbRes, _ := resolvers.NewMiniblockResolver(arg) - err := mbRes.ProcessReceivedMessage( + msgID, err := mbRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.HashArrayType, requestedBuff), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -318,6 +324,7 @@ func TestMiniblockResolver_ProcessReceivedMessageUnmarshalFails(t *testing.T) { assert.True(t, errors.Is(err, expectedErr)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestMiniblockResolver_ProcessReceivedMessagePackDataInChunksFails(t *testing.T) { @@ -331,7 +338,7 @@ func TestMiniblockResolver_ProcessReceivedMessagePackDataInChunksFails(t *testin assert.Nil(t, merr) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -353,7 +360,7 @@ func TestMiniblockResolver_ProcessReceivedMessagePackDataInChunksFails(t *testin } mbRes, _ := resolvers.NewMiniblockResolver(arg) - err := mbRes.ProcessReceivedMessage( + msgID, err := mbRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.HashArrayType, requestedBuff), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -362,6 +369,7 @@ func TestMiniblockResolver_ProcessReceivedMessagePackDataInChunksFails(t *testin assert.True(t, errors.Is(err, expectedErr)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestMiniblockResolver_ProcessReceivedMessageSendFails(t *testing.T) { @@ -375,7 +383,7 @@ func TestMiniblockResolver_ProcessReceivedMessageSendFails(t *testing.T) { assert.Nil(t, merr) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -397,7 +405,7 @@ func TestMiniblockResolver_ProcessReceivedMessageSendFails(t *testing.T) { } mbRes, _ := resolvers.NewMiniblockResolver(arg) - err := mbRes.ProcessReceivedMessage( + msgID, err := mbRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.HashArrayType, requestedBuff), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -406,6 +414,7 @@ func TestMiniblockResolver_ProcessReceivedMessageSendFails(t *testing.T) { assert.True(t, errors.Is(err, expectedErr)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestMiniblockResolver_ProcessReceivedMessageNotFoundInPoolShouldRetFromStorageAndSend(t *testing.T) { @@ -420,7 +429,7 @@ func TestMiniblockResolver_ProcessReceivedMessageNotFoundInPoolShouldRetFromStor wasResolved := false wasSend := false - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -443,7 +452,7 @@ func TestMiniblockResolver_ProcessReceivedMessageNotFoundInPoolShouldRetFromStor arg.MiniBlockStorage = store mbRes, 
_ := resolvers.NewMiniblockResolver(arg) - err := mbRes.ProcessReceivedMessage( + msgID, err := mbRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.HashType, requestedBuff), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -454,6 +463,7 @@ func TestMiniblockResolver_ProcessReceivedMessageNotFoundInPoolShouldRetFromStor assert.True(t, wasSend) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestMiniblockResolver_ProcessReceivedMessageMarshalFails(t *testing.T) { @@ -467,7 +477,7 @@ func TestMiniblockResolver_ProcessReceivedMessageMarshalFails(t *testing.T) { wasResolved := false - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -496,7 +506,7 @@ func TestMiniblockResolver_ProcessReceivedMessageMarshalFails(t *testing.T) { } mbRes, _ := resolvers.NewMiniblockResolver(arg) - err := mbRes.ProcessReceivedMessage( + msgID, err := mbRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.HashType, requestedBuff), fromConnectedPeerId, &p2pmocks.MessengerStub{}, @@ -506,6 +516,7 @@ func TestMiniblockResolver_ProcessReceivedMessageMarshalFails(t *testing.T) { assert.True(t, wasResolved) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestMiniblockResolver_ProcessReceivedMessageMissingDataShouldNotSend(t *testing.T) { @@ -519,7 +530,7 @@ func TestMiniblockResolver_ProcessReceivedMessageMissingDataShouldNotSend(t *tes wasSent := false - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -540,7 +551,7 @@ func TestMiniblockResolver_ProcessReceivedMessageMissingDataShouldNotSend(t *tes arg.MiniBlockStorage = store mbRes, _ := resolvers.NewMiniblockResolver(arg) - _ = mbRes.ProcessReceivedMessage( + _, _ = mbRes.ProcessReceivedMessage( createRequestMsg(dataRetriever.HashType, requestedBuff), fromConnectedPeerId, &p2pmocks.MessengerStub{}, diff --git a/dataRetriever/resolvers/peerAuthenticationResolver.go b/dataRetriever/resolvers/peerAuthenticationResolver.go index dc2a45892c2..49f29ff0246 100644 --- a/dataRetriever/resolvers/peerAuthenticationResolver.go +++ b/dataRetriever/resolvers/peerAuthenticationResolver.go @@ -6,11 +6,12 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/storage" - logger "github.com/multiversx/mx-chain-logger-go" ) // maxBuffToSendPeerAuthentications represents max buffer size to send in bytes @@ -75,10 +76,10 @@ func checkArgPeerAuthenticationResolver(arg ArgPeerAuthenticationResolver) error // ProcessReceivedMessage represents the callback func from the p2p.Messenger that is called each time a new message is received // (for the topic this validator was registered to, usually a request topic) -func (res *peerAuthenticationResolver) 
ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error { +func (res *peerAuthenticationResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { err := res.canProcessMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } res.throttler.StartProcessing() @@ -86,20 +87,20 @@ func (res *peerAuthenticationResolver) ProcessReceivedMessage(message p2p.Messag rd, err := res.parseReceivedMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } switch rd.Type { case dataRetriever.HashArrayType: - return res.resolveMultipleHashesRequest(rd.Value, message.Peer(), source) + return nil, res.resolveMultipleHashesRequest(rd.Value, message.Peer(), source) default: err = dataRetriever.ErrRequestTypeNotImplemented } if err != nil { - err = fmt.Errorf("%w for value %s", err, logger.DisplayByteSlice(rd.Value)) + return nil, fmt.Errorf("%w for value %s", err, logger.DisplayByteSlice(rd.Value)) } - return err + return []byte{}, nil } // resolveMultipleHashesRequest sends the response for multiple hashes request diff --git a/dataRetriever/resolvers/peerAuthenticationResolver_test.go b/dataRetriever/resolvers/peerAuthenticationResolver_test.go index 188c29d7e3f..4918a5f041d 100644 --- a/dataRetriever/resolvers/peerAuthenticationResolver_test.go +++ b/dataRetriever/resolvers/peerAuthenticationResolver_test.go @@ -13,15 +13,17 @@ import ( "github.com/multiversx/mx-chain-core-go/core/partitioning" "github.com/multiversx/mx-chain-core-go/data/batch" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/dataRetriever/resolvers" "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var expectedErr = errors.New("expected error") @@ -57,7 +59,7 @@ func createMockPeerAuthenticationObject() interface{} { func createMockArgPeerAuthenticationResolver() resolvers.ArgPeerAuthenticationResolver { return resolvers.ArgPeerAuthenticationResolver{ ArgBaseResolver: createMockArgBaseResolver(), - PeerAuthenticationPool: testscommon.NewCacherStub(), + PeerAuthenticationPool: cache.NewCacherStub(), DataPacker: &mock.DataPackerStub{}, PayloadValidator: &testscommon.PeerAuthenticationPayloadValidatorStub{}, } @@ -164,8 +166,9 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { assert.Nil(t, err) assert.False(t, res.IsInterfaceNil()) - err = res.ProcessReceivedMessage(nil, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(nil, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, dataRetriever.ErrNilMessage, err) + assert.Nil(t, msgID) }) t.Run("canProcessMessage due to antiflood handler error", func(t *testing.T) { t.Parallel() @@ -180,10 +183,11 @@ 
func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { assert.Nil(t, err) assert.False(t, res.IsInterfaceNil()) - err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.ChunkType, nil), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.ChunkType, nil), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) }) t.Run("parseReceivedMessage returns error due to marshaller error", func(t *testing.T) { t.Parallel() @@ -198,8 +202,9 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { assert.Nil(t, err) assert.False(t, res.IsInterfaceNil()) - err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.ChunkType, nil), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.ChunkType, nil), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) + assert.Nil(t, msgID) }) t.Run("invalid request type should error", func(t *testing.T) { t.Parallel() @@ -213,8 +218,9 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { assert.Nil(t, err) assert.False(t, res.IsInterfaceNil()) - err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedBuff), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, requestedBuff), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, dataRetriever.ErrRequestTypeNotImplemented)) + assert.Nil(t, msgID) }) // =============== HashArrayType -> resolveMultipleHashesRequest =============== @@ -227,13 +233,14 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { assert.Nil(t, err) assert.False(t, res.IsInterfaceNil()) - err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, []byte("invalid data")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, []byte("invalid data")), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.NotNil(t, err) + assert.Nil(t, msgID) }) t.Run("resolveMultipleHashesRequest: all hashes missing from cache should error", func(t *testing.T) { t.Parallel() - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return nil, false } @@ -254,15 +261,16 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { hashes := getKeysSlice() providedHashes, err := arg.Marshaller.Marshal(batch.Batch{Data: hashes}) assert.Nil(t, err) - err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) expectedSubstrErr := fmt.Sprintf("%s %x", "from buff", providedHashes) assert.True(t, strings.Contains(fmt.Sprintf("%s", err), expectedSubstrErr)) assert.False(t, wasSent) + assert.Nil(t, msgID) }) t.Run("resolveMultipleHashesRequest: all hashes will return wrong objects should error", func(t *testing.T) { t.Parallel() - cache := testscommon.NewCacherStub() + 
cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return "wrong object", true } @@ -283,16 +291,17 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { hashes := getKeysSlice() providedHashes, err := arg.Marshaller.Marshal(batch.Batch{Data: hashes}) assert.Nil(t, err) - err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) expectedSubstrErr := fmt.Sprintf("%s %x", "from buff", providedHashes) assert.True(t, strings.Contains(fmt.Sprintf("%s", err), expectedSubstrErr)) assert.False(t, wasSent) + assert.Nil(t, msgID) }) t.Run("resolveMultipleHashesRequest: all hashes will return objects with invalid payload should error", func(t *testing.T) { t.Parallel() arg := createMockArgPeerAuthenticationResolver() - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return createMockPeerAuthenticationObject(), true } @@ -319,11 +328,12 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { hashes := getKeysSlice() providedHashes, err := arg.Marshaller.Marshal(batch.Batch{Data: hashes}) assert.Nil(t, err) - err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) expectedSubstrErr := fmt.Sprintf("%s %x", "from buff", providedHashes) assert.True(t, strings.Contains(fmt.Sprintf("%s", err), expectedSubstrErr)) assert.False(t, wasSent) assert.Equal(t, len(hashes), numValidationCalls) + assert.Nil(t, msgID) }) t.Run("resolveMultipleHashesRequest: some data missing from cache should work", func(t *testing.T) { t.Parallel() @@ -349,7 +359,7 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { providedHashes, err := arg.Marshaller.Marshal(batch.Batch{Data: hashes}) assert.Nil(t, err) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { val, ok := providedKeys[string(key)] return val, ok @@ -387,14 +397,15 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { assert.Nil(t, err) assert.False(t, res.IsInterfaceNil()) - err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, wasSent) + assert.Nil(t, msgID) }) t.Run("resolveMultipleHashesRequest: PackDataInChunks returns error", func(t *testing.T) { t.Parallel() - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return createMockPeerAuthenticationObject(), true } @@ -413,13 +424,14 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { hashes := getKeysSlice() providedHashes, err := arg.Marshaller.Marshal(batch.Batch{Data: hashes}) assert.Nil(t, err) - err = 
res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) + assert.Nil(t, msgID) }) t.Run("resolveMultipleHashesRequest: Send returns error", func(t *testing.T) { t.Parallel() - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { return createMockPeerAuthenticationObject(), true } @@ -438,15 +450,16 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { hashes := getKeysSlice() providedHashes, err := arg.Marshaller.Marshal(batch.Batch{Data: hashes}) assert.Nil(t, err) - err = res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, providedHashes), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) + assert.Nil(t, msgID) }) t.Run("resolveMultipleHashesRequest: send large data buff", func(t *testing.T) { t.Parallel() providedKeys := getKeysSlice() expectedLen := len(providedKeys) - cache := testscommon.NewCacherStub() + cache := cache.NewCacherStub() cache.PeekCalled = func(key []byte) (value interface{}, ok bool) { for _, pk := range providedKeys { if bytes.Equal(pk, key) { @@ -501,9 +514,10 @@ func TestPeerAuthenticationResolver_ProcessReceivedMessage(t *testing.T) { chunkIndex := uint32(0) providedHashes, err := arg.Marshaller.Marshal(&batch.Batch{Data: providedKeys}) assert.Nil(t, err) - err = res.ProcessReceivedMessage(createRequestMsgWithChunkIndex(dataRetriever.HashArrayType, providedHashes, epoch, chunkIndex), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsgWithChunkIndex(dataRetriever.HashArrayType, providedHashes, epoch, chunkIndex), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.Equal(t, 2, messagesSent) assert.Equal(t, expectedLen, hashesReceived) + assert.Nil(t, msgID) }) } diff --git a/dataRetriever/resolvers/transactionResolver.go b/dataRetriever/resolvers/transactionResolver.go index 3a88bd13c15..8495c970a70 100644 --- a/dataRetriever/resolvers/transactionResolver.go +++ b/dataRetriever/resolvers/transactionResolver.go @@ -6,10 +6,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/storage" - logger "github.com/multiversx/mx-chain-logger-go" ) var _ dataRetriever.Resolver = (*TxResolver)(nil) @@ -82,10 +83,10 @@ func checkArgTxResolver(arg ArgTxResolver) error { // ProcessReceivedMessage will be the callback func from the p2p.Messenger and will be called each time a new message was received // (for the topic this validator was registered to, usually a request topic) -func (txRes *TxResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) 
error { +func (txRes *TxResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { err := txRes.canProcessMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } txRes.throttler.StartProcessing() @@ -93,7 +94,7 @@ func (txRes *TxResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConn rd, err := txRes.parseReceivedMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } switch rd.Type { @@ -106,10 +107,10 @@ func (txRes *TxResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConn } if err != nil { - err = fmt.Errorf("%w for hash %s", err, logger.DisplayByteSlice(rd.Value)) + return nil, fmt.Errorf("%w for hash %s", err, logger.DisplayByteSlice(rd.Value)) } - return err + return []byte{}, nil } func (txRes *TxResolver) resolveTxRequestByHash(hash []byte, pid core.PeerID, epoch uint32, source p2p.MessageHandler) error { diff --git a/dataRetriever/resolvers/transactionResolver_test.go b/dataRetriever/resolvers/transactionResolver_test.go index 2af167aae70..7d7679b9b78 100644 --- a/dataRetriever/resolvers/transactionResolver_test.go +++ b/dataRetriever/resolvers/transactionResolver_test.go @@ -9,6 +9,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/dataRetriever/resolvers" @@ -16,7 +18,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" ) var connectedPeerId = core.PeerID("connected peer id") @@ -131,11 +132,12 @@ func TestTxResolver_ProcessReceivedMessageCanProcessMessageErrorsShouldErr(t *te } txRes, _ := resolvers.NewTxResolver(arg) - err := txRes.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{}, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{}, connectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTxResolver_ProcessReceivedMessageNilMessageShouldErr(t *testing.T) { @@ -144,11 +146,12 @@ func TestTxResolver_ProcessReceivedMessageNilMessageShouldErr(t *testing.T) { arg := createMockArgTxResolver() txRes, _ := resolvers.NewTxResolver(arg) - err := txRes.ProcessReceivedMessage(nil, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(nil, connectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, dataRetriever.ErrNilMessage, err) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTxResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T) { @@ -161,11 +164,12 @@ func TestTxResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T) { msg := 
&p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, dataRetriever.ErrRequestTypeNotImplemented)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTxResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { @@ -178,11 +182,12 @@ func TestTxResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { msg := &p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, dataRetriever.ErrNilValue, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTxResolver_ProcessReceivedMessageFoundInTxPoolShouldSearchAndSend(t *testing.T) { @@ -218,13 +223,14 @@ func TestTxResolver_ProcessReceivedMessageFoundInTxPoolShouldSearchAndSend(t *te msg := &p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, searchWasCalled) assert.True(t, sendWasCalled) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestTxResolver_ProcessReceivedMessageFoundInTxPoolMarshalizerFailShouldRetNilAndErr(t *testing.T) { @@ -262,11 +268,12 @@ func TestTxResolver_ProcessReceivedMessageFoundInTxPoolMarshalizerFailShouldRetN msg := &p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, errExpected)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTxResolver_ProcessReceivedMessageBatchMarshalFailShouldRetNilAndErr(t *testing.T) { @@ -307,11 +314,12 @@ func TestTxResolver_ProcessReceivedMessageBatchMarshalFailShouldRetNilAndErr(t * msg := &p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTxResolver_ProcessReceivedMessageFoundInTxStorageShouldRetValAndSend(t *testing.T) { @@ -321,7 +329,7 @@ func TestTxResolver_ProcessReceivedMessageFoundInTxStorageShouldRetValAndSend(t txPool := testscommon.NewShardedDataStub() txPool.SearchFirstDataCalled = func(key []byte) (value interface{}, ok bool) { - //not found in txPool + // not found in txPool return nil, false } searchWasCalled := false @@ -355,13 +363,14 @@ func TestTxResolver_ProcessReceivedMessageFoundInTxStorageShouldRetValAndSend(t msg := 
&p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, searchWasCalled) assert.True(t, sendWasCalled) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestTxResolver_ProcessReceivedMessageFoundInTxStorageCheckRetError(t *testing.T) { @@ -371,7 +380,7 @@ func TestTxResolver_ProcessReceivedMessageFoundInTxStorageCheckRetError(t *testi txPool := testscommon.NewShardedDataStub() txPool.SearchFirstDataCalled = func(key []byte) (value interface{}, ok bool) { - //not found in txPool + // not found in txPool return nil, false } @@ -395,11 +404,12 @@ func TestTxResolver_ProcessReceivedMessageFoundInTxStorageCheckRetError(t *testi msg := &p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, errExpected)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTxResolver_ProcessReceivedMessageRequestedTwoSmallTransactionsShouldCallSliceSplitter(t *testing.T) { @@ -455,13 +465,14 @@ func TestTxResolver_ProcessReceivedMessageRequestedTwoSmallTransactionsShouldCal msg := &p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, splitSliceWasCalled) assert.True(t, sendWasCalled) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestTxResolver_ProcessReceivedMessageRequestedTwoSmallTransactionsFoundOnlyOneShouldWork(t *testing.T) { @@ -516,13 +527,14 @@ func TestTxResolver_ProcessReceivedMessageRequestedTwoSmallTransactionsFoundOnly msg := &p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.NotNil(t, err) assert.True(t, splitSliceWasCalled) assert.True(t, sendWasCalled) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTxResolver_ProcessReceivedMessageHashArrayUnmarshalFails(t *testing.T) { @@ -545,11 +557,12 @@ func TestTxResolver_ProcessReceivedMessageHashArrayUnmarshalFails(t *testing.T) data, _ := marshalizer.Marshal(&dataRetriever.RequestData{Type: dataRetriever.HashArrayType, Value: []byte("buff")}) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func 
TestTxResolver_ProcessReceivedMessageHashArrayPackDataInChunksFails(t *testing.T) { @@ -570,11 +583,12 @@ func TestTxResolver_ProcessReceivedMessageHashArrayPackDataInChunksFails(t *test data, _ := arg.Marshaller.Marshal(&dataRetriever.RequestData{Type: dataRetriever.HashArrayType, Value: buff}) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTxResolver_ProcessReceivedMessageHashArraySendFails(t *testing.T) { @@ -595,11 +609,12 @@ func TestTxResolver_ProcessReceivedMessageHashArraySendFails(t *testing.T) { data, _ := arg.Marshaller.Marshal(&dataRetriever.RequestData{Type: dataRetriever.HashArrayType, Value: buff}) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := txRes.ProcessReceivedMessage(msg, connectedPeerId, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTxResolver_Close(t *testing.T) { diff --git a/dataRetriever/resolvers/trieNodeResolver.go b/dataRetriever/resolvers/trieNodeResolver.go index 275327d44c6..1d2936a2eda 100644 --- a/dataRetriever/resolvers/trieNodeResolver.go +++ b/dataRetriever/resolvers/trieNodeResolver.go @@ -6,9 +6,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" - logger "github.com/multiversx/mx-chain-logger-go" ) var _ dataRetriever.Resolver = (*TrieNodeResolver)(nil) @@ -62,10 +63,10 @@ func checkArgTrieNodeResolver(arg ArgTrieNodeResolver) error { // ProcessReceivedMessage will be the callback func from the p2p.Messenger and will be called each time a new message was received // (for the topic this validator was registered to, usually a request topic) -func (tnRes *TrieNodeResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error { +func (tnRes *TrieNodeResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { err := tnRes.canProcessMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } tnRes.throttler.StartProcessing() @@ -73,17 +74,22 @@ func (tnRes *TrieNodeResolver) ProcessReceivedMessage(message p2p.MessageP2P, fr rd, err := tnRes.parseReceivedMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } switch rd.Type { case dataRetriever.HashType: - return tnRes.resolveOneHash(rd.Value, rd.ChunkIndex, message, source) + err = tnRes.resolveOneHash(rd.Value, rd.ChunkIndex, message, source) case dataRetriever.HashArrayType: - return tnRes.resolveMultipleHashes(rd.Value, message, source) + err = 
tnRes.resolveMultipleHashes(rd.Value, message, source) default: - return dataRetriever.ErrRequestTypeNotImplemented + err = dataRetriever.ErrRequestTypeNotImplemented + } + + if err != nil { + return nil, err } + return []byte{}, nil } func (tnRes *TrieNodeResolver) resolveMultipleHashes(hashesBuff []byte, message p2p.MessageP2P, source p2p.MessageHandler) error { @@ -214,7 +220,7 @@ func (tnRes *TrieNodeResolver) sendResponse( ) error { if len(serializedNodes) == 0 { - //do not send useless message + // do not send useless message return nil } diff --git a/dataRetriever/resolvers/trieNodeResolver_test.go b/dataRetriever/resolvers/trieNodeResolver_test.go index b2706f02b36..b988b2f2959 100644 --- a/dataRetriever/resolvers/trieNodeResolver_test.go +++ b/dataRetriever/resolvers/trieNodeResolver_test.go @@ -10,14 +10,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/dataRetriever/resolvers" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var fromConnectedPeer = core.PeerID("from connected peer") @@ -108,10 +109,11 @@ func TestTrieNodeResolver_ProcessReceivedAntiflooderCanProcessMessageErrShouldEr } tnRes, _ := resolvers.NewTrieNodeResolver(arg) - err := tnRes.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{}, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{}, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTrieNodeResolver_ProcessReceivedMessageNilMessageShouldErr(t *testing.T) { @@ -120,10 +122,11 @@ func TestTrieNodeResolver_ProcessReceivedMessageNilMessageShouldErr(t *testing.T arg := createMockArgTrieNodeResolver() tnRes, _ := resolvers.NewTrieNodeResolver(arg) - err := tnRes.ProcessReceivedMessage(nil, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(nil, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, dataRetriever.ErrNilMessage, err) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.False(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTrieNodeResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T) { @@ -137,10 +140,11 @@ func TestTrieNodeResolver_ProcessReceivedMessageWrongTypeShouldErr(t *testing.T) data, _ := marshalizer.Marshal(&dataRetriever.RequestData{Type: dataRetriever.NonceType, Value: []byte("aaa")}) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) 
assert.Equal(t, dataRetriever.ErrRequestTypeNotImplemented, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTrieNodeResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) { @@ -154,13 +158,14 @@ func TestTrieNodeResolver_ProcessReceivedMessageNilValueShouldErr(t *testing.T) data, _ := marshalizer.Marshal(&dataRetriever.RequestData{Type: dataRetriever.HashType, Value: nil}) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, dataRetriever.ErrNilValue, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } -//TODO in this PR: add more unit tests +// TODO in this PR: add more unit tests func TestTrieNodeResolver_ProcessReceivedMessageShouldGetFromTrieAndSend(t *testing.T) { t.Parallel() @@ -193,13 +198,14 @@ func TestTrieNodeResolver_ProcessReceivedMessageShouldGetFromTrieAndSend(t *test data, _ := marshalizer.Marshal(&dataRetriever.RequestData{Type: dataRetriever.HashType, Value: []byte("node1")}) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, getSerializedNodesWasCalled) assert.True(t, sendWasCalled) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestTrieNodeResolver_ProcessReceivedMessageShouldGetFromTrieAndMarshalizerFailShouldRetNilAndErr(t *testing.T) { @@ -223,10 +229,11 @@ func TestTrieNodeResolver_ProcessReceivedMessageShouldGetFromTrieAndMarshalizerF data, _ := marshalizerMock.Marshal(&dataRetriever.RequestData{Type: dataRetriever.HashType, Value: []byte("node1")}) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, errExpected, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTrieNodeResolver_ProcessReceivedMessageTrieErrorsShouldErr(t *testing.T) { @@ -243,10 +250,11 @@ func TestTrieNodeResolver_ProcessReceivedMessageTrieErrorsShouldErr(t *testing.T data, _ := arg.Marshaller.Marshal(&dataRetriever.RequestData{Type: dataRetriever.HashType, Value: []byte("node1")}) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, expectedErr, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesUnmarshalFails(t *testing.T) { @@ -286,10 +294,11 @@ func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesUnmarshalFails(t * ) msg := 
&p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, expectedErr, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) } func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesGetSerializedNodeErrorsShouldNotSend(t *testing.T) { @@ -322,10 +331,11 @@ func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesGetSerializedNodeE ) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Len(t, msgID, 0) } func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesGetSerializedNodesErrorsShouldNotSendSubtrie(t *testing.T) { @@ -375,12 +385,13 @@ func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesGetSerializedNodes ) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) require.Equal(t, 1, len(receivedNodes)) assert.Equal(t, nodes[0], receivedNodes[0]) + assert.Len(t, msgID, 0) } func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesNotEnoughSpaceShouldNotReadSubtries(t *testing.T) { @@ -431,12 +442,13 @@ func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesNotEnoughSpaceShou ) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) require.Equal(t, 1, len(receivedNodes)) assert.Equal(t, nodes[0], receivedNodes[0]) + assert.Len(t, msgID, 0) } func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesShouldWorkWithSubtries(t *testing.T) { @@ -492,11 +504,12 @@ func TestTrieNodeResolver_ProcessReceivedMessageMultipleHashesShouldWorkWithSubt ) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) require.Equal(t, 4, len(receivedNodes)) + assert.Len(t, msgID, 0) for _, n := range nodes { assert.True(t, buffInSlice(n, receivedNodes)) } @@ -558,17 +571,18 @@ func testTrieNodeResolverProcessReceivedMessageLargeTrieNode( ) msg := &p2pmocks.P2PMessageMock{DataField: data} - err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := tnRes.ProcessReceivedMessage(msg, fromConnectedPeer, 
&p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.True(t, arg.Throttler.(*mock.ThrottlerStub).EndWasCalled()) require.True(t, sendWasCalled) + assert.Len(t, msgID, 0) } func TestTrieNodeResolver_ProcessReceivedMessageLargeTrieNodeShouldSendFirstChunk(t *testing.T) { t.Parallel() - randBuff := make([]byte, 1<<20) //1MB + randBuff := make([]byte, 1<<20) // 1MB _, _ = rand.Read(randBuff) testTrieNodeResolverProcessReceivedMessageLargeTrieNode(t, randBuff, 0, 4, 0, core.MaxBufferSizeToSendTrieNodes) } @@ -576,7 +590,7 @@ func TestTrieNodeResolver_ProcessReceivedMessageLargeTrieNodeShouldSendFirstChun func TestTrieNodeResolver_ProcessReceivedMessageLargeTrieNodeShouldSendRequiredChunk(t *testing.T) { t.Parallel() - randBuff := make([]byte, 1<<20) //1MB + randBuff := make([]byte, 1<<20) // 1MB _, _ = rand.Read(randBuff) testTrieNodeResolverProcessReceivedMessageLargeTrieNode( t, @@ -603,7 +617,7 @@ func TestTrieNodeResolver_ProcessReceivedMessageLargeTrieNodeShouldSendRequiredC 4*core.MaxBufferSizeToSendTrieNodes, ) - randBuff = make([]byte, 1<<20+1) //1MB + 1 byte + randBuff = make([]byte, 1<<20+1) // 1MB + 1 byte _, _ = rand.Read(randBuff) startIndex := len(randBuff) - 1 endIndex := len(randBuff) diff --git a/dataRetriever/resolvers/validatorInfoResolver.go b/dataRetriever/resolvers/validatorInfoResolver.go index 9f7e5a6bb1a..65255b8ad8f 100644 --- a/dataRetriever/resolvers/validatorInfoResolver.go +++ b/dataRetriever/resolvers/validatorInfoResolver.go @@ -8,10 +8,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/storage" - logger "github.com/multiversx/mx-chain-logger-go" ) // maxBuffToSendValidatorsInfo represents max buffer size to send in bytes @@ -89,10 +90,10 @@ func checkArgs(args ArgValidatorInfoResolver) error { // ProcessReceivedMessage represents the callback func from the p2p.Messenger that is called each time a new message is received // (for the topic this validator was registered to, usually a request topic) -func (res *validatorInfoResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error { +func (res *validatorInfoResolver) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { err := res.canProcessMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } res.throttler.StartProcessing() @@ -100,17 +101,22 @@ func (res *validatorInfoResolver) ProcessReceivedMessage(message p2p.MessageP2P, rd, err := res.parseReceivedMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } switch rd.Type { case dataRetriever.HashType: - return res.resolveHashRequest(rd.Value, rd.Epoch, fromConnectedPeer, source) + err = res.resolveHashRequest(rd.Value, rd.Epoch, fromConnectedPeer, source) case dataRetriever.HashArrayType: - return res.resolveMultipleHashesRequest(rd.Value, rd.Epoch, fromConnectedPeer, source) + err = res.resolveMultipleHashesRequest(rd.Value, rd.Epoch, fromConnectedPeer, source) + default: + err = 
fmt.Errorf("%w for value %s", dataRetriever.ErrRequestTypeNotImplemented, logger.DisplayByteSlice(rd.Value)) } - return fmt.Errorf("%w for value %s", dataRetriever.ErrRequestTypeNotImplemented, logger.DisplayByteSlice(rd.Value)) + if err != nil { + return nil, err + } + return []byte{}, nil } // resolveHashRequest sends the response for a hash request diff --git a/dataRetriever/resolvers/validatorInfoResolver_test.go b/dataRetriever/resolvers/validatorInfoResolver_test.go index d17fd1aedb4..fe74a43ec51 100644 --- a/dataRetriever/resolvers/validatorInfoResolver_test.go +++ b/dataRetriever/resolvers/validatorInfoResolver_test.go @@ -10,6 +10,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/partitioning" "github.com/multiversx/mx-chain-core-go/data/batch" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/mock" @@ -21,8 +24,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockArgValidatorInfoResolver() resolvers.ArgValidatorInfoResolver { @@ -141,8 +142,9 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { res, _ := resolvers.NewValidatorInfoResolver(createMockArgValidatorInfoResolver()) require.False(t, check.IfNil(res)) - err := res.ProcessReceivedMessage(nil, fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(nil, fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, dataRetriever.ErrNilMessage, err) + assert.Nil(t, msgID) }) t.Run("canProcessMessage due to antiflood handler error", func(t *testing.T) { t.Parallel() @@ -156,10 +158,11 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { res, _ := resolvers.NewValidatorInfoResolver(args) require.False(t, check.IfNil(res)) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) assert.False(t, args.Throttler.(*mock.ThrottlerStub).StartWasCalled()) assert.False(t, args.Throttler.(*mock.ThrottlerStub).EndWasCalled()) + assert.Nil(t, msgID) }) t.Run("parseReceivedMessage returns error due to marshalizer error", func(t *testing.T) { t.Parallel() @@ -173,8 +176,9 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { res, _ := resolvers.NewValidatorInfoResolver(args) require.False(t, check.IfNil(res)) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, nil), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, expectedErr)) + assert.Nil(t, msgID) }) t.Run("invalid request type should error", func(t *testing.T) { @@ -183,8 +187,9 @@ func 
TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { res, _ := resolvers.NewValidatorInfoResolver(createMockArgValidatorInfoResolver()) require.False(t, check.IfNil(res)) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.NonceType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.True(t, errors.Is(err, dataRetriever.ErrRequestTypeNotImplemented)) + assert.Nil(t, msgID) }) // resolveHashRequest @@ -205,8 +210,9 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { res, _ := resolvers.NewValidatorInfoResolver(args) require.False(t, check.IfNil(res)) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, expectedErr, err) + assert.Nil(t, msgID) }) t.Run("data found in cache but marshal fails", func(t *testing.T) { t.Parallel() @@ -229,8 +235,9 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { res, _ := resolvers.NewValidatorInfoResolver(args) require.False(t, check.IfNil(res)) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.NotNil(t, err) + assert.Nil(t, msgID) }) t.Run("data found in storage but marshal fails", func(t *testing.T) { t.Parallel() @@ -258,8 +265,9 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { res, _ := resolvers.NewValidatorInfoResolver(args) require.False(t, check.IfNil(res)) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.NotNil(t, err) + assert.Nil(t, msgID) }) t.Run("should work, data from cache", func(t *testing.T) { t.Parallel() @@ -290,9 +298,10 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { res, _ := resolvers.NewValidatorInfoResolver(args) require.False(t, check.IfNil(res)) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, wasCalled) + assert.Len(t, msgID, 0) }) t.Run("should work, data from storage", func(t *testing.T) { t.Parallel() @@ -329,9 +338,10 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { res, _ := resolvers.NewValidatorInfoResolver(args) require.False(t, check.IfNil(res)) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, wasCalled) + assert.Len(t, msgID, 0) }) // 
resolveMultipleHashesRequest @@ -353,8 +363,9 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { res, _ := resolvers.NewValidatorInfoResolver(args) require.False(t, check.IfNil(res)) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, []byte("hash")), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, expectedErr, err) + assert.Nil(t, msgID) }) t.Run("no hash found", func(t *testing.T) { t.Parallel() @@ -377,9 +388,10 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { Data: [][]byte{[]byte("hash")}, } buff, _ := args.Marshaller.Marshal(b) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, buff), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, buff), fromConnectedPeer, &p2pmocks.MessengerStub{}) require.NotNil(t, err) assert.True(t, strings.Contains(err.Error(), dataRetriever.ErrValidatorInfoNotFound.Error())) + assert.Nil(t, msgID) }) t.Run("pack data in chunks returns error", func(t *testing.T) { t.Parallel() @@ -407,8 +419,9 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { Data: [][]byte{[]byte("hash")}, } buff, _ := args.Marshaller.Marshal(b) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, buff), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, buff), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, expectedErr, err) + assert.Nil(t, msgID) }) t.Run("send returns error", func(t *testing.T) { t.Parallel() @@ -441,8 +454,9 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { require.False(t, check.IfNil(res)) buff, _ := args.Marshaller.Marshal(&batch.Batch{Data: providedHashes}) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, buff), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, buff), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Equal(t, expectedErr, err) + assert.Nil(t, msgID) }) t.Run("all hashes in one chunk should work", func(t *testing.T) { t.Parallel() @@ -489,9 +503,10 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { require.False(t, check.IfNil(res)) buff, _ := args.Marshaller.Marshal(&batch.Batch{Data: providedHashes}) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, buff), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, buff), fromConnectedPeer, &p2pmocks.MessengerStub{}) assert.Nil(t, err) assert.True(t, wasCalled) + assert.Len(t, msgID, 0) }) t.Run("multiple chunks should work", func(t *testing.T) { t.Parallel() @@ -551,10 +566,11 @@ func TestValidatorInfoResolver_ProcessReceivedMessage(t *testing.T) { require.False(t, check.IfNil(res)) buff, _ := args.Marshaller.Marshal(&batch.Batch{Data: providedHashes}) - err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, buff), fromConnectedPeer, &p2pmocks.MessengerStub{}) + msgID, err := res.ProcessReceivedMessage(createRequestMsg(dataRetriever.HashArrayType, buff), fromConnectedPeer, &p2pmocks.MessengerStub{}) 
assert.Nil(t, err) assert.Equal(t, 2, numOfCallsSend) // ~677 messages in a chunk assert.Equal(t, 0, len(providedDataMap)) // all items should have been deleted on Send + assert.Len(t, msgID, 0) }) } diff --git a/dataRetriever/shardedData/shardedData.go b/dataRetriever/shardedData/shardedData.go index 785998164b9..0724473d07b 100644 --- a/dataRetriever/shardedData/shardedData.go +++ b/dataRetriever/shardedData/shardedData.go @@ -22,10 +22,10 @@ const untitledCacheName = "untitled" // shardedData holds the list of data organised by destination shard // -// The shardStores field maps a cacher, containing data -// hashes, to a corresponding identifier. It is able to add or remove -// data given the shard id it is associated with. It can -// also merge and split pools when required +// The shardStores field maps a cacher, containing data +// hashes, to a corresponding identifier. It is able to add or remove +// data given the shard id it is associated with. It can +// also merge and split pools when required type shardedData struct { name string mutShardedDataStore sync.RWMutex @@ -201,7 +201,8 @@ func (sd *shardedData) RemoveData(key []byte, cacheID string) { } // RemoveDataFromAllShards will remove data from the store given only -// the data hash. It will iterate over all shard store map and will remove it everywhere +// +// the data hash. It will iterate over all shard store map and will remove it everywhere func (sd *shardedData) RemoveDataFromAllShards(key []byte) { sd.mutShardedDataStore.RLock() defer sd.mutShardedDataStore.RUnlock() diff --git a/dataRetriever/storageRequesters/equivalentProofsRequester.go b/dataRetriever/storageRequesters/equivalentProofsRequester.go new file mode 100644 index 00000000000..4454aa890cf --- /dev/null +++ b/dataRetriever/storageRequesters/equivalentProofsRequester.go @@ -0,0 +1,144 @@ +package storagerequesters + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/multiversx/mx-chain-core-go/data/typeConverters" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/storage" + "time" +) + +// ArgEquivalentProofsRequester is the argument structure used to create a new equivalent proofs requester instance +type ArgEquivalentProofsRequester struct { + Messenger dataRetriever.MessageHandler + ResponseTopicName string + ManualEpochStartNotifier dataRetriever.ManualEpochStartNotifier + ChanGracefullyClose chan endProcess.ArgEndProcess + DelayBeforeGracefulClose time.Duration + NonceConverter typeConverters.Uint64ByteSliceConverter + Storage dataRetriever.StorageService + Marshaller marshal.Marshalizer + EnableEpochsHandler core.EnableEpochsHandler +} + +type equivalentProofsRequester struct { + *storageRequester + nonceConverter typeConverters.Uint64ByteSliceConverter + storage dataRetriever.StorageService + marshaller marshal.Marshalizer + enableEpochsHandler core.EnableEpochsHandler +} + +// NewEquivalentProofsRequester returns a new instance of equivalent proofs requester +func NewEquivalentProofsRequester(args ArgEquivalentProofsRequester) (*equivalentProofsRequester, error) { + err := checkArgs(args) + if err != nil { + return nil, err + } + + return 
&equivalentProofsRequester{ + storageRequester: &storageRequester{ + messenger: args.Messenger, + responseTopicName: args.ResponseTopicName, + manualEpochStartNotifier: args.ManualEpochStartNotifier, + chanGracefullyClose: args.ChanGracefullyClose, + delayBeforeGracefulClose: args.DelayBeforeGracefulClose, + }, + nonceConverter: args.NonceConverter, + storage: args.Storage, + marshaller: args.Marshaller, + enableEpochsHandler: args.EnableEpochsHandler, + }, nil +} + +func checkArgs(args ArgEquivalentProofsRequester) error { + if check.IfNil(args.Messenger) { + return dataRetriever.ErrNilMessenger + } + if check.IfNil(args.ManualEpochStartNotifier) { + return dataRetriever.ErrNilManualEpochStartNotifier + } + if args.ChanGracefullyClose == nil { + return dataRetriever.ErrNilGracefullyCloseChannel + } + if check.IfNil(args.NonceConverter) { + return dataRetriever.ErrNilUint64ByteSliceConverter + } + if check.IfNil(args.Storage) { + return dataRetriever.ErrNilStore + } + if check.IfNil(args.Marshaller) { + return dataRetriever.ErrNilMarshalizer + } + if check.IfNil(args.EnableEpochsHandler) { + return dataRetriever.ErrNilEnableEpochsHandler + } + + return nil +} + +// RequestDataFromHash requests equivalent proofs data from storage for the specified hash-shard key +func (requester *equivalentProofsRequester) RequestDataFromHash(hashShardKey []byte, epoch uint32) error { + if !requester.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + return nil + } + + headerHash, _, err := common.GetHashAndShardFromKey(hashShardKey) + if err != nil { + return err + } + + equivalentProofsStorage, err := requester.storage.GetStorer(dataRetriever.ProofsUnit) + if err != nil { + return err + } + + buff, err := equivalentProofsStorage.SearchFirst(headerHash) + if err != nil { + return err + } + + return requester.sendToSelf(buff) +} + +// RequestDataFromNonce requests equivalent proofs data from storage for the specified nonce-shard key +func (requester *equivalentProofsRequester) RequestDataFromNonce(nonceShardKey []byte, epoch uint32) error { + if !requester.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + return nil + } + + headerNonce, shardID, err := common.GetNonceAndShardFromKey(nonceShardKey) + if err != nil { + return err + } + storer, err := requester.getStorerForShard(shardID) + if err != nil { + return err + } + + nonceKey := requester.nonceConverter.ToByteSlice(headerNonce) + hash, err := storer.SearchFirst(nonceKey) + if err != nil { + return err + } + + hashShardKey := common.GetEquivalentProofHashShardKey(hash, shardID) + return requester.RequestDataFromHash([]byte(hashShardKey), epoch) +} + +func (requester *equivalentProofsRequester) getStorerForShard(shardID uint32) (storage.Storer, error) { + if shardID == core.MetachainShardId { + return requester.storage.GetStorer(dataRetriever.MetaHdrNonceHashDataUnit) + } + + return requester.storage.GetStorer(dataRetriever.GetHdrNonceHashDataUnit(shardID)) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (requester *equivalentProofsRequester) IsInterfaceNil() bool { + return requester == nil +} diff --git a/dataRetriever/storageRequesters/equivalentProofsRequester_test.go b/dataRetriever/storageRequesters/equivalentProofsRequester_test.go new file mode 100644 index 00000000000..d8ed8b1f08c --- /dev/null +++ b/dataRetriever/storageRequesters/equivalentProofsRequester_test.go @@ -0,0 +1,280 @@ +package storagerequesters + +import ( + "testing" + + 
"github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/dataRetriever/mock" + chainStorage "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/genericMocks" + "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" + "github.com/multiversx/mx-chain-go/testscommon/storage" + "github.com/stretchr/testify/require" +) + +func createMockArgEquivalentProofsRequester() ArgEquivalentProofsRequester { + return ArgEquivalentProofsRequester{ + Messenger: &mock.MessageHandlerStub{}, + ResponseTopicName: "", + ManualEpochStartNotifier: &mock.ManualEpochStartNotifierStub{}, + ChanGracefullyClose: make(chan endProcess.ArgEndProcess), + NonceConverter: &mock.Uint64ByteSliceConverterMock{ + ToByteSliceCalled: func(u uint64) []byte { + return make([]byte, 0) + }, + ToUint64Called: func(bytes []byte) (uint64, error) { + return 0, nil + }, + }, + Storage: &genericMocks.ChainStorerMock{}, + Marshaller: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + }, + } +} + +func TestNewEquivalentProofsRequester(t *testing.T) { + t.Parallel() + + t.Run("nil Messenger should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + args.Messenger = nil + req, err := NewEquivalentProofsRequester(args) + require.Equal(t, dataRetriever.ErrNilMessenger, err) + require.Nil(t, req) + }) + t.Run("nil ManualEpochStartNotifier should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + args.ManualEpochStartNotifier = nil + req, err := NewEquivalentProofsRequester(args) + require.Equal(t, dataRetriever.ErrNilManualEpochStartNotifier, err) + require.Nil(t, req) + }) + t.Run("nil ChanGracefullyClose should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + args.ChanGracefullyClose = nil + req, err := NewEquivalentProofsRequester(args) + require.Equal(t, dataRetriever.ErrNilGracefullyCloseChannel, err) + require.Nil(t, req) + }) + t.Run("nil NonceConverter should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + args.NonceConverter = nil + req, err := NewEquivalentProofsRequester(args) + require.Equal(t, dataRetriever.ErrNilUint64ByteSliceConverter, err) + require.Nil(t, req) + }) + t.Run("nil Storage should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + args.Storage = nil + req, err := NewEquivalentProofsRequester(args) + require.Equal(t, dataRetriever.ErrNilStore, err) + require.Nil(t, req) + }) + t.Run("nil Marshaller should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + args.Marshaller = nil + req, err := NewEquivalentProofsRequester(args) + require.Equal(t, dataRetriever.ErrNilMarshalizer, err) + require.Nil(t, req) + }) + t.Run("nil EnableEpochsHandler should error", func(t *testing.T) { + t.Parallel() + + args := 
createMockArgEquivalentProofsRequester() + args.EnableEpochsHandler = nil + req, err := NewEquivalentProofsRequester(args) + require.Equal(t, dataRetriever.ErrNilEnableEpochsHandler, err) + require.Nil(t, req) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + req, err := NewEquivalentProofsRequester(createMockArgEquivalentProofsRequester()) + require.NoError(t, err) + require.NotNil(t, req) + }) +} + +func TestEquivalentProofsRequester_IsInterfaceNil(t *testing.T) { + var req *equivalentProofsRequester + require.True(t, req.IsInterfaceNil()) + + req, _ = NewEquivalentProofsRequester(createMockArgEquivalentProofsRequester()) + require.False(t, req.IsInterfaceNil()) +} + +func TestEquivalentProofsRequester_RequestDataFromHash(t *testing.T) { + t.Parallel() + + t.Run("invalid key should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + req, err := NewEquivalentProofsRequester(args) + require.NoError(t, err) + + err = req.RequestDataFromHash([]byte("invalid key"), 0) + require.Error(t, err) + }) + t.Run("GetStorer error should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + args.Storage = &storage.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (chainStorage.Storer, error) { + return nil, expectedErr + }, + } + req, err := NewEquivalentProofsRequester(args) + require.NoError(t, err) + + err = req.RequestDataFromHash([]byte(common.GetEquivalentProofHashShardKey([]byte("hash"), 1)), 0) + require.Equal(t, expectedErr, err) + }) + t.Run("SearchFirst error should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + args.Storage = &storage.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (chainStorage.Storer, error) { + return &storage.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + return nil, expectedErr + }, + }, nil + }, + } + req, err := NewEquivalentProofsRequester(args) + require.NoError(t, err) + + err = req.RequestDataFromHash([]byte(common.GetEquivalentProofHashShardKey([]byte("hash"), 1)), 0) + require.Equal(t, expectedErr, err) + }) + t.Run("should work and send to self", func(t *testing.T) { + t.Parallel() + + providedBuff := []byte("provided buff") + args := createMockArgEquivalentProofsRequester() + args.Storage = &storage.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (chainStorage.Storer, error) { + return &storage.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + return providedBuff, nil + }, + }, nil + }, + } + wasSendToConnectedPeerCalled := false + args.Messenger = &p2pmocks.MessengerStub{ + SendToConnectedPeerCalled: func(topic string, buff []byte, peerID core.PeerID) error { + wasSendToConnectedPeerCalled = true + require.Equal(t, string(providedBuff), string(buff)) + return nil + }, + } + req, err := NewEquivalentProofsRequester(args) + require.NoError(t, err) + + err = req.RequestDataFromHash([]byte(common.GetEquivalentProofHashShardKey([]byte("hash"), 1)), 0) + require.NoError(t, err) + require.True(t, wasSendToConnectedPeerCalled) + }) +} + +func TestEquivalentProofsRequester_RequestDataFromNonce(t *testing.T) { + t.Parallel() + + t.Run("invalid key should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + req, err := NewEquivalentProofsRequester(args) + require.NoError(t, err) + + err = req.RequestDataFromNonce([]byte("invalid key"), 0) + 
require.Error(t, err) + }) + t.Run("getStorerForShard error should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + args.Storage = &storage.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (chainStorage.Storer, error) { + return nil, expectedErr + }, + } + req, err := NewEquivalentProofsRequester(args) + require.NoError(t, err) + + err = req.RequestDataFromNonce([]byte(common.GetEquivalentProofNonceShardKey(123, 1)), 0) + require.Equal(t, expectedErr, err) + }) + t.Run("SearchFirst error should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgEquivalentProofsRequester() + args.Storage = &storage.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (chainStorage.Storer, error) { + return &storage.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + return nil, expectedErr + }, + }, nil + }, + } + req, err := NewEquivalentProofsRequester(args) + require.NoError(t, err) + + err = req.RequestDataFromNonce([]byte(common.GetEquivalentProofNonceShardKey(123, core.MetachainShardId)), 0) + require.Equal(t, expectedErr, err) + }) + t.Run("should work and send to self", func(t *testing.T) { + t.Parallel() + + providedBuff := []byte("provided buff") + args := createMockArgEquivalentProofsRequester() + args.Storage = &storage.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (chainStorage.Storer, error) { + return &storage.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + return providedBuff, nil + }, + }, nil + }, + } + wasSendToConnectedPeerCalled := false + args.Messenger = &p2pmocks.MessengerStub{ + SendToConnectedPeerCalled: func(topic string, buff []byte, peerID core.PeerID) error { + wasSendToConnectedPeerCalled = true + require.Equal(t, string(providedBuff), string(buff)) + return nil + }, + } + req, err := NewEquivalentProofsRequester(args) + require.NoError(t, err) + + err = req.RequestDataFromNonce([]byte(common.GetEquivalentProofNonceShardKey(123, 1)), 0) + require.NoError(t, err) + require.True(t, wasSendToConnectedPeerCalled) + }) +} diff --git a/dataRetriever/unitType.go b/dataRetriever/unitType.go index 22bba7dc2b8..66775fb6179 100644 --- a/dataRetriever/unitType.go +++ b/dataRetriever/unitType.go @@ -41,7 +41,7 @@ const ( TrieEpochRootHashUnit UnitType = 17 // ESDTSuppliesUnit is the ESDT supplies storage unit identifier ESDTSuppliesUnit UnitType = 18 - // RoundHdrHashDataUnit is the round- block header hash storage data unit identifier + // RoundHdrHashDataUnit is the round-block header hash storage data unit identifier RoundHdrHashDataUnit UnitType = 19 // UserAccountsUnit is the user accounts storage unit identifier UserAccountsUnit UnitType = 20 @@ -49,6 +49,8 @@ const ( PeerAccountsUnit UnitType = 21 // ScheduledSCRsUnit is the scheduled SCRs storage unit identifier ScheduledSCRsUnit UnitType = 22 + // ProofsUnit is the header proofs unit identifier + ProofsUnit UnitType = 23 // ShardHdrNonceHashDataUnit is the header nonce-hash pair data unit identifier //TODO: Add only unit types lower than 100 @@ -110,6 +112,8 @@ func (ut UnitType) String() string { return "PeerAccountsUnit" case ScheduledSCRsUnit: return "ScheduledSCRsUnit" + case ProofsUnit: + return "ProofsUnit" } if ut < ShardHdrNonceHashDataUnit { diff --git a/epochStart/bootstrap/baseStorageHandler.go b/epochStart/bootstrap/baseStorageHandler.go index d7a18094a07..f0abdd78257 100644 --- a/epochStart/bootstrap/baseStorageHandler.go +++ 
b/epochStart/bootstrap/baseStorageHandler.go @@ -37,6 +37,8 @@ type StorageHandlerArgs struct { NodeProcessingMode common.NodeProcessingMode RepopulateTokensSupplies bool StateStatsHandler common.StateStatisticsHandler + ProofsPool ProofsPool + EnableEpochsHandler common.EnableEpochsHandler } func checkNilArgs(args StorageHandlerArgs) error { @@ -58,6 +60,13 @@ func checkNilArgs(args StorageHandlerArgs) error { if check.IfNil(args.NodesCoordinatorRegistryFactory) { return nodesCoordinator.ErrNilNodesCoordinatorRegistryFactory } + if check.IfNil(args.ProofsPool) { + return dataRetriever.ErrNilProofsPool + } + if check.IfNil(args.EnableEpochsHandler) { + return core.ErrNilEnableEpochsHandler + } + return nil } @@ -83,6 +92,8 @@ type baseStorageHandler struct { currentEpoch uint32 uint64Converter typeConverters.Uint64ByteSliceConverter nodesCoordinatorRegistryFactory nodesCoordinator.NodesCoordinatorRegistryFactory + proofsPool ProofsPool + enableEpochsHandler common.EnableEpochsHandler } func (bsh *baseStorageHandler) groupMiniBlocksByShard(miniBlocks map[string]*block.MiniBlock) ([]bootstrapStorage.PendingMiniBlocksInfo, error) { @@ -103,6 +114,34 @@ func (bsh *baseStorageHandler) groupMiniBlocksByShard(miniBlocks map[string]*blo return sliceToRet, nil } +func (bsh *baseStorageHandler) saveProofToStorage(shardID uint32, headerHash []byte, header data.HeaderHandler) error { + if !bsh.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + return nil + } + + proof, err := bsh.proofsPool.GetProof(shardID, headerHash) + if err != nil { + return err + } + + proofsStorer, err := bsh.storageService.GetStorer(dataRetriever.ProofsUnit) + if err != nil { + return err + } + + marshalledProof, errMarshal := bsh.marshalizer.Marshal(proof) + if errMarshal != nil { + return errMarshal + } + + errPut := proofsStorer.Put(proof.GetHeaderHash(), marshalledProof) + if errPut != nil { + return errPut + } + + return nil +} + func (bsh *baseStorageHandler) saveNodesCoordinatorRegistry( metaBlock data.HeaderHandler, nodesConfig nodesCoordinator.NodesCoordinatorRegistryHandler, @@ -158,6 +197,11 @@ func (bsh *baseStorageHandler) saveMetaHdrToStorage(metaBlock data.HeaderHandler return nil, err } + err = bsh.saveProofToStorage(core.MetachainShardId, headerHash, metaBlock) + if err != nil { + return nil, err + } + return headerHash, nil } @@ -197,6 +241,11 @@ func (bsh *baseStorageHandler) saveShardHdrToStorage(hdr data.HeaderHandler) ([] return nil, err } + err = bsh.saveProofToStorage(hdr.GetShardID(), headerHash, hdr) + if err != nil { + return nil, err + } + return headerHash, nil } diff --git a/epochStart/bootstrap/common.go b/epochStart/bootstrap/common.go index da6e99fda1b..a6621f86ed8 100644 --- a/epochStart/bootstrap/common.go +++ b/epochStart/bootstrap/common.go @@ -123,6 +123,9 @@ func checkArguments(args ArgsEpochStartBootstrap) error { if check.IfNil(args.NodesCoordinatorRegistryFactory) { return fmt.Errorf("%s: %w", baseErrorMessage, nodesCoordinator.ErrNilNodesCoordinatorRegistryFactory) } + if check.IfNil(args.EnableEpochsHandler) { + return fmt.Errorf("%s: %w", baseErrorMessage, epochStart.ErrNilEnableEpochsHandler) + } return nil } diff --git a/epochStart/bootstrap/disabled/disabledAntiFloodHandler.go b/epochStart/bootstrap/disabled/disabledAntiFloodHandler.go index cc1065b9d98..96d656b9e13 100644 --- a/epochStart/bootstrap/disabled/disabledAntiFloodHandler.go +++ b/epochStart/bootstrap/disabled/disabledAntiFloodHandler.go @@ -29,8 +29,8 @@ func (a *antiFloodHandler) 
CanProcessMessagesOnTopic(_ core.PeerID, _ string, _ return nil } -// ApplyConsensusSize does nothing -func (a *antiFloodHandler) ApplyConsensusSize(_ int) { +// SetConsensusSizeNotifier does nothing +func (a *antiFloodHandler) SetConsensusSizeNotifier(_ process.ChainParametersSubscriber, _ uint32) { } // SetDebugger returns nil diff --git a/epochStart/bootstrap/disabled/disabledHeaderSigVerifier.go b/epochStart/bootstrap/disabled/disabledHeaderSigVerifier.go index d5de2e34380..7245930ab27 100644 --- a/epochStart/bootstrap/disabled/disabledHeaderSigVerifier.go +++ b/epochStart/bootstrap/disabled/disabledHeaderSigVerifier.go @@ -2,6 +2,7 @@ package disabled import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -15,27 +16,37 @@ func NewHeaderSigVerifier() *headerSigVerifier { return &headerSigVerifier{} } -// VerifyRandSeed - +// VerifyRandSeed returns nil as it is disabled func (h *headerSigVerifier) VerifyRandSeed(_ data.HeaderHandler) error { return nil } -// VerifyLeaderSignature - +// VerifyLeaderSignature returns nil as it is disabled func (h *headerSigVerifier) VerifyLeaderSignature(_ data.HeaderHandler) error { return nil } -// VerifyRandSeedAndLeaderSignature - +// VerifyRandSeedAndLeaderSignature returns nil as it is disabled func (h *headerSigVerifier) VerifyRandSeedAndLeaderSignature(_ data.HeaderHandler) error { return nil } -// VerifySignature - +// VerifySignature returns nil as it is disabled func (h *headerSigVerifier) VerifySignature(_ data.HeaderHandler) error { return nil } -// IsInterfaceNil - +// VerifySignatureForHash returns nil as it is disabled +func (h *headerSigVerifier) VerifySignatureForHash(_ data.HeaderHandler, _ []byte, _ []byte, _ []byte) error { + return nil +} + +// VerifyHeaderProof returns nil as it is disabled +func (h *headerSigVerifier) VerifyHeaderProof(_ data.HeaderProofHandler) error { + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface func (h *headerSigVerifier) IsInterfaceNil() bool { return h == nil } diff --git a/epochStart/bootstrap/disabled/disabledNodesCoordinator.go b/epochStart/bootstrap/disabled/disabledNodesCoordinator.go index f7c1502d0c4..031255aad93 100644 --- a/epochStart/bootstrap/disabled/disabledNodesCoordinator.go +++ b/epochStart/bootstrap/disabled/disabledNodesCoordinator.go @@ -44,6 +44,11 @@ func (n *nodesCoordinator) GetAllEligibleValidatorsPublicKeys(_ uint32) (map[uin return nil, nil } +// GetAllEligibleValidatorsPublicKeysForShard - +func (n *nodesCoordinator) GetAllEligibleValidatorsPublicKeysForShard(_ uint32, _ uint32) ([]string, error) { + return nil, nil +} + // GetAllWaitingValidatorsPublicKeys - func (n *nodesCoordinator) GetAllWaitingValidatorsPublicKeys(_ uint32) (map[uint32][][]byte, error) { return nil, nil @@ -60,8 +65,8 @@ func (n *nodesCoordinator) GetShuffledOutToAuctionValidatorsPublicKeys(_ uint32) } // GetConsensusValidatorsPublicKeys - -func (n *nodesCoordinator) GetConsensusValidatorsPublicKeys(_ []byte, _ uint64, _ uint32, _ uint32) ([]string, error) { - return nil, nil +func (n *nodesCoordinator) GetConsensusValidatorsPublicKeys(_ []byte, _ uint64, _ uint32, _ uint32) (string, []string, error) { + return "", nil, nil } // GetOwnPublicKey - @@ -70,8 +75,8 @@ func (n *nodesCoordinator) GetOwnPublicKey() []byte { } // ComputeConsensusGroup - -func (n *nodesCoordinator) ComputeConsensusGroup(_ []byte, _ uint64, _ uint32, _ uint32) (validatorsGroup []nodesCoord.Validator, err 
error) { - return nil, nil +func (n *nodesCoordinator) ComputeConsensusGroup(_ []byte, _ uint64, _ uint32, _ uint32) (leader nodesCoord.Validator, validatorsGroup []nodesCoord.Validator, err error) { + return nil, nil, nil } // GetValidatorWithPublicKey - @@ -103,8 +108,8 @@ func (n *nodesCoordinator) GetConsensusWhitelistedNodes(_ uint32) (map[string]st return nil, nil } -// ConsensusGroupSize - -func (n *nodesCoordinator) ConsensusGroupSize(uint32) int { +// ConsensusGroupSizeForShardAndEpoch - +func (n *nodesCoordinator) ConsensusGroupSizeForShardAndEpoch(uint32, uint32) int { return 0 } @@ -118,6 +123,11 @@ func (n *nodesCoordinator) GetWaitingEpochsLeftForPublicKey(_ []byte) (uint32, e return 0, nil } +// GetCachedEpochs returns an empty map +func (n *nodesCoordinator) GetCachedEpochs() map[uint32]struct{} { + return make(map[uint32]struct{}) +} + // IsInterfaceNil - func (n *nodesCoordinator) IsInterfaceNil() bool { return n == nil diff --git a/epochStart/bootstrap/epochStartMetaBlockProcessor.go b/epochStart/bootstrap/epochStartMetaBlockProcessor.go index ff1a4370ad7..419a243f3fd 100644 --- a/epochStart/bootstrap/epochStartMetaBlockProcessor.go +++ b/epochStart/bootstrap/epochStartMetaBlockProcessor.go @@ -12,6 +12,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" @@ -26,15 +28,21 @@ const minNumConnectedPeers = 1 var _ process.InterceptorProcessor = (*epochStartMetaBlockProcessor)(nil) type epochStartMetaBlockProcessor struct { - messenger Messenger - requestHandler RequestHandler - marshalizer marshal.Marshalizer - hasher hashing.Hasher - mutReceivedMetaBlocks sync.RWMutex - mapReceivedMetaBlocks map[string]data.MetaHeaderHandler - mapMetaBlocksFromPeers map[string][]core.PeerID - chanConsensusReached chan bool + messenger Messenger + requestHandler RequestHandler + marshalizer marshal.Marshalizer + hasher hashing.Hasher + enableEpochsHandler common.EnableEpochsHandler + proofsPool ProofsPool + + mutReceivedMetaBlocks sync.RWMutex + mapReceivedMetaBlocks map[string]data.MetaHeaderHandler + mapMetaBlocksFromPeers map[string][]core.PeerID + + chanMetaBlockProofReached chan bool + chanMetaBlockReached chan bool metaBlock data.MetaHeaderHandler + metaBlockHash string peerCountTarget int minNumConnectedPeers int minNumOfPeersToConsiderBlockValid int @@ -49,6 +57,8 @@ func NewEpochStartMetaBlockProcessor( consensusPercentage uint8, minNumConnectedPeersConfig int, minNumOfPeersToConsiderBlockValidConfig int, + enableEpochsHandler common.EnableEpochsHandler, + proofsPool ProofsPool, ) (*epochStartMetaBlockProcessor, error) { if check.IfNil(messenger) { return nil, epochStart.ErrNilMessenger @@ -71,6 +81,12 @@ func NewEpochStartMetaBlockProcessor( if minNumOfPeersToConsiderBlockValidConfig < minNumPeersToConsiderMetaBlockValid { return nil, epochStart.ErrNotEnoughNumOfPeersToConsiderBlockValid } + if check.IfNil(enableEpochsHandler) { + return nil, epochStart.ErrNilEnableEpochsHandler + } + if check.IfNil(proofsPool) { + return nil, epochStart.ErrNilProofsPool + } processor := &epochStartMetaBlockProcessor{ messenger: messenger, @@ -79,12 +95,17 @@ func 
NewEpochStartMetaBlockProcessor( hasher: hasher, minNumConnectedPeers: minNumConnectedPeersConfig, minNumOfPeersToConsiderBlockValid: minNumOfPeersToConsiderBlockValidConfig, + enableEpochsHandler: enableEpochsHandler, mutReceivedMetaBlocks: sync.RWMutex{}, mapReceivedMetaBlocks: make(map[string]data.MetaHeaderHandler), mapMetaBlocksFromPeers: make(map[string][]core.PeerID), - chanConsensusReached: make(chan bool, 1), + chanMetaBlockProofReached: make(chan bool, 1), + chanMetaBlockReached: make(chan bool, 1), + proofsPool: proofsPool, } + proofsPool.RegisterHandler(processor.receivedProof) + processor.waitForEnoughNumConnectedPeers(messenger) percentage := float64(consensusPercentage) / 100.0 peerCountTarget := int(percentage * float64(len(messenger.ConnectedPeers()))) @@ -136,18 +157,17 @@ func (e *epochStartMetaBlockProcessor) Save(data process.InterceptedData, fromCo return nil } - if !metaBlock.IsStartOfEpochBlock() { - log.Debug("received metablock is not of type epoch start", "error", epochStart.ErrNotEpochStartBlock) - return nil - } - mbHash := interceptedHdr.Hash() - log.Debug("received epoch start meta", "epoch", metaBlock.GetEpoch(), "from peer", fromConnectedPeer.Pretty()) - e.mutReceivedMetaBlocks.Lock() - e.mapReceivedMetaBlocks[string(mbHash)] = metaBlock - e.addToPeerList(string(mbHash), fromConnectedPeer) - e.mutReceivedMetaBlocks.Unlock() + if metaBlock.IsStartOfEpochBlock() { + log.Debug("received epoch start meta block", "epoch", metaBlock.GetEpoch(), "from peer", fromConnectedPeer.Pretty()) + e.mutReceivedMetaBlocks.Lock() + e.mapReceivedMetaBlocks[string(mbHash)] = metaBlock + e.addToPeerList(string(mbHash), fromConnectedPeer) + e.mutReceivedMetaBlocks.Unlock() + + return nil + } return nil } @@ -180,33 +200,82 @@ func (e *epochStartMetaBlockProcessor) GetEpochStartMetaBlock(ctx context.Contex } }() - err = e.requestMetaBlock() + metaBlock, metaBlockHash, err := e.waitForMetaBlock(ctx) if err != nil { return nil, err } + e.requestHandler.SetEpoch(metaBlock.GetEpoch()) + if e.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, metaBlock.GetEpoch()) { + err = e.waitForMetaBlockProof(ctx, []byte(metaBlockHash)) + if err != nil { + return nil, err + } + } + + return metaBlock, nil +} + +func (e *epochStartMetaBlockProcessor) waitForMetaBlock(ctx context.Context) (data.MetaHeaderHandler, string, error) { + err := e.requestMetaBlock() + if err != nil { + return nil, "", err + } + chanRequests := time.After(durationBetweenReRequests) chanCheckMaps := time.After(durationBetweenChecks) + for { select { - case <-e.chanConsensusReached: - return e.metaBlock, nil + case <-e.chanMetaBlockReached: + return e.metaBlock, e.metaBlockHash, nil case <-ctx.Done(): return e.getMostReceivedMetaBlock() case <-chanRequests: err = e.requestMetaBlock() if err != nil { - return nil, err + return nil, "", err } chanRequests = time.After(durationBetweenReRequests) case <-chanCheckMaps: - e.checkMaps() + e.checkMetaBlockMaps() chanCheckMaps = time.After(durationBetweenChecks) } } } -func (e *epochStartMetaBlockProcessor) getMostReceivedMetaBlock() (data.MetaHeaderHandler, error) { +func (e *epochStartMetaBlockProcessor) waitForMetaBlockProof( + ctx context.Context, + metaBlockHash []byte, +) error { + if e.proofsPool.HasProof(core.MetachainShardId, metaBlockHash) { + return nil + } + + err := e.requestProofForMetaBlock(metaBlockHash) + if err != nil { + return err + } + + chanRequests := time.After(durationBetweenReRequests) + + for { + select { + case <-e.chanMetaBlockProofReached: + 
return nil + case <-ctx.Done(): + return epochStart.ErrTimeoutWaitingForMetaBlock + case <-chanRequests: + err = e.requestProofForMetaBlock(metaBlockHash) + if err != nil { + return err + } + chanRequests = time.After(durationBetweenReRequests) + } + } +} + +func (e *epochStartMetaBlockProcessor) getMostReceivedMetaBlock() (data.MetaHeaderHandler, string, error) { e.mutReceivedMetaBlocks.RLock() defer e.mutReceivedMetaBlocks.RUnlock() @@ -220,10 +289,10 @@ func (e *epochStartMetaBlockProcessor) getMostReceivedMetaBlock() (data.MetaHead } if len(mostReceivedHash) == 0 { - return nil, epochStart.ErrTimeoutWaitingForMetaBlock + return nil, "", epochStart.ErrTimeoutWaitingForMetaBlock } - return e.mapReceivedMetaBlocks[mostReceivedHash], nil + return e.mapReceivedMetaBlocks[mostReceivedHash], mostReceivedHash, nil } func (e *epochStartMetaBlockProcessor) requestMetaBlock() error { @@ -238,27 +307,74 @@ func (e *epochStartMetaBlockProcessor) requestMetaBlock() error { return nil } -func (e *epochStartMetaBlockProcessor) checkMaps() { +func (e *epochStartMetaBlockProcessor) requestProofForMetaBlock(metablockHash []byte) error { + numConnectedPeers := len(e.messenger.ConnectedPeers()) + topic := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + err := e.requestHandler.SetNumPeersToQuery(topic, numConnectedPeers, numConnectedPeers) + if err != nil { + return err + } + + e.requestHandler.RequestEquivalentProofByHash(core.MetachainShardId, metablockHash) + + return nil +} + +func (e *epochStartMetaBlockProcessor) receivedProof(proof data.HeaderProofHandler) { + startOfEpochMetaBlock, hash, err := e.getMostReceivedMetaBlock() + if err != nil { + return + } + + hashesMatchMostReceived := string(proof.GetHeaderHash()) == hash + hashesMatchLocal := string(proof.GetHeaderHash()) == e.metaBlockHash + if !hashesMatchMostReceived && !hashesMatchLocal { + return + } + + metaBlock := e.metaBlock + if hashesMatchMostReceived { + metaBlock = startOfEpochMetaBlock + } + + err = common.VerifyProofAgainstHeader(proof, metaBlock) + if err != nil { + return + } + + e.chanMetaBlockProofReached <- true +} + +func (e *epochStartMetaBlockProcessor) checkMetaBlockMaps() { e.mutReceivedMetaBlocks.RLock() defer e.mutReceivedMetaBlocks.RUnlock() - for hash, peersList := range e.mapMetaBlocksFromPeers { + hash, metaBlockFound := e.checkReceivedMetaBlock(e.mapMetaBlocksFromPeers) + if metaBlockFound { + e.metaBlock = e.mapReceivedMetaBlocks[hash] + e.metaBlockHash = hash + e.chanMetaBlockReached <- true + } +} + +func (e *epochStartMetaBlockProcessor) checkReceivedMetaBlock(blocksFromPeers map[string][]core.PeerID) (string, bool) { + for hash, peersList := range blocksFromPeers { log.Debug("metablock from peers", "num peers", len(peersList), "target", e.peerCountTarget, "hash", []byte(hash)) - found := e.processEntry(peersList, hash) - if found { - break + + metaBlockFound := e.processMetaBlockEntry(peersList) + if metaBlockFound { + return hash, true } } + + return "", false } -func (e *epochStartMetaBlockProcessor) processEntry( +func (e *epochStartMetaBlockProcessor) processMetaBlockEntry( peersList []core.PeerID, - hash string, ) bool { if len(peersList) >= e.peerCountTarget { log.Info("got consensus for epoch start metablock", "len", len(peersList)) - e.metaBlock = e.mapReceivedMetaBlocks[hash] - e.chanConsensusReached <- true return true } diff --git a/epochStart/bootstrap/epochStartMetaBlockProcessor_test.go 
b/epochStart/bootstrap/epochStartMetaBlockProcessor_test.go index 1741c63a25c..2f550842284 100644 --- a/epochStart/bootstrap/epochStartMetaBlockProcessor_test.go +++ b/epochStart/bootstrap/epochStartMetaBlockProcessor_test.go @@ -9,9 +9,12 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" @@ -28,6 +31,8 @@ func TestNewEpochStartMetaBlockProcessor_NilMessengerShouldErr(t *testing.T) { 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, epochStart.ErrNilMessenger, err) @@ -45,6 +50,8 @@ func TestNewEpochStartMetaBlockProcessor_NilRequestHandlerShouldErr(t *testing.T 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, epochStart.ErrNilRequestHandler, err) @@ -62,6 +69,8 @@ func TestNewEpochStartMetaBlockProcessor_NilMarshalizerShouldErr(t *testing.T) { 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, epochStart.ErrNilMarshalizer, err) @@ -79,6 +88,8 @@ func TestNewEpochStartMetaBlockProcessor_NilHasherShouldErr(t *testing.T) { 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, epochStart.ErrNilHasher, err) @@ -96,6 +107,8 @@ func TestNewEpochStartMetaBlockProcessor_InvalidConsensusPercentageShouldErr(t * 101, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, epochStart.ErrInvalidConsensusThreshold, err) @@ -116,6 +129,8 @@ func TestNewEpochStartMetaBlockProcessorOkValsShouldWork(t *testing.T) { 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.NoError(t, err) @@ -152,6 +167,8 @@ func TestNewEpochStartMetaBlockProcessorOkValsShouldWorkAfterMoreTriesWaitingFor 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.NoError(t, err) @@ -172,6 +189,8 @@ func TestEpochStartMetaBlockProcessor_Validate(t *testing.T) { 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Nil(t, esmbp.Validate(nil, "")) @@ -191,6 +210,8 @@ func TestEpochStartMetaBlockProcessor_SaveNilInterceptedDataShouldNotReturnError 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) err := esmbp.Save(nil, "peer0", "") @@ -212,6 +233,8 @@ func TestEpochStartMetaBlockProcessor_SaveOkInterceptedDataShouldWork(t *testing 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Zero(t, len(esmbp.GetMapMetaBlock())) @@ -241,6 +264,8 @@ func 
TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldTimeOut(t *tes 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) @@ -264,21 +289,31 @@ func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldReturnMostRece &hashingMocks.HasherMock{}, 99, 3, - 3, + 5, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) expectedMetaBlock := &block.MetaBlock{ Nonce: 10, EpochStart: block.EpochStart{LastFinalizedHeaders: []block.EpochStartShardData{{Round: 1}}}, } + confirmationMetaBlock := &block.MetaBlock{ + Nonce: 11, + } intData := mock.NewInterceptedMetaBlockMock(expectedMetaBlock, []byte("hash")) + intData2 := mock.NewInterceptedMetaBlockMock(confirmationMetaBlock, []byte("hash2")) for i := 0; i < esmbp.minNumOfPeersToConsiderBlockValid; i++ { _ = esmbp.Save(intData, core.PeerID(fmt.Sprintf("peer_%d", i)), "") } + for i := 0; i < esmbp.minNumOfPeersToConsiderBlockValid; i++ { + _ = esmbp.Save(intData2, core.PeerID(fmt.Sprintf("peer_%d", i)), "") + } + // we need a slightly more time than 1 second in order to also properly test the select branches - timeout := time.Second + time.Millisecond*500 + timeout := 2*time.Second + time.Millisecond*500 ctx, cancel := context.WithTimeout(context.Background(), timeout) mb, err := esmbp.GetEpochStartMetaBlock(ctx) cancel() @@ -301,18 +336,28 @@ func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldWorkFromFirstT 50, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) expectedMetaBlock := &block.MetaBlock{ Nonce: 10, EpochStart: block.EpochStart{LastFinalizedHeaders: []block.EpochStartShardData{{Round: 1}}}, } + confirmationMetaBlock := &block.MetaBlock{ + Nonce: 11, + } intData := mock.NewInterceptedMetaBlockMock(expectedMetaBlock, []byte("hash")) + intData2 := mock.NewInterceptedMetaBlockMock(confirmationMetaBlock, []byte("hash2")) for i := 0; i < 6; i++ { _ = esmbp.Save(intData, core.PeerID(fmt.Sprintf("peer_%d", i)), "") } + for i := 0; i < 6; i++ { + _ = esmbp.Save(intData2, core.PeerID(fmt.Sprintf("peer_%d", i)), "") + } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) mb, err := esmbp.GetEpochStartMetaBlock(ctx) cancel() @@ -320,19 +365,54 @@ func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldWorkFromFirstT assert.Equal(t, expectedMetaBlock, mb) } -func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldWorkAfterMultipleTries(t *testing.T) { +func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlock_BeforeAndromeda(t *testing.T) { t.Parallel() - testEpochStartMbIsReceivedWithSleepBetweenReceivedMessages(t, durationBetweenChecks-10*time.Millisecond) + tts := durationBetweenChecks - 10*time.Millisecond + + esmbp, _ := NewEpochStartMetaBlockProcessor( + &p2pmocks.MessengerStub{ + ConnectedPeersCalled: func() []core.PeerID { + return []core.PeerID{"peer_0", "peer_1", "peer_2", "peer_3", "peer_4", "peer_5"} + }, + }, + &testscommon.RequestHandlerStub{}, + &mock.MarshalizerMock{}, + &hashingMocks.HasherMock{}, + 64, + 3, + 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) + expectedMetaBlock := &block.MetaBlock{ + Nonce: 10, + EpochStart: block.EpochStart{LastFinalizedHeaders: []block.EpochStartShardData{{Round: 1}}}, + } + intData := mock.NewInterceptedMetaBlockMock(expectedMetaBlock, []byte("hash")) + + go func() { + index := 
0 + for { + time.Sleep(tts) + _ = esmbp.Save(intData, core.PeerID(fmt.Sprintf("peer_%d", index)), "") + _ = esmbp.Save(intData, core.PeerID(fmt.Sprintf("peer_%d", index+1)), "") + index += 2 + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + mb, err := esmbp.GetEpochStartMetaBlock(ctx) + cancel() + assert.NoError(t, err) + assert.Equal(t, expectedMetaBlock, mb) } -func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlockShouldWorkAfterMultipleRequests(t *testing.T) { +func TestEpochStartMetaBlockProcessor_GetEpochStartMetaBlock_AfterAndromeda(t *testing.T) { t.Parallel() - testEpochStartMbIsReceivedWithSleepBetweenReceivedMessages(t, durationBetweenChecks-10*time.Millisecond) -} + tts := durationBetweenChecks - 10*time.Millisecond -func testEpochStartMbIsReceivedWithSleepBetweenReceivedMessages(t *testing.T, tts time.Duration) { esmbp, _ := NewEpochStartMetaBlockProcessor( &p2pmocks.MessengerStub{ ConnectedPeersCalled: func() []core.PeerID { @@ -345,12 +425,28 @@ func testEpochStartMbIsReceivedWithSleepBetweenReceivedMessages(t *testing.T, tt 64, 3, 3, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + }, + &dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + }, ) expectedMetaBlock := &block.MetaBlock{ Nonce: 10, EpochStart: block.EpochStart{LastFinalizedHeaders: []block.EpochStartShardData{{Round: 1}}}, } intData := mock.NewInterceptedMetaBlockMock(expectedMetaBlock, []byte("hash")) + + confirmationMetaBlock := &block.MetaBlock{ + Nonce: 11, + } + intData2 := mock.NewInterceptedMetaBlockMock(confirmationMetaBlock, []byte("hash2")) + go func() { index := 0 for { @@ -360,6 +456,17 @@ func testEpochStartMbIsReceivedWithSleepBetweenReceivedMessages(t *testing.T, tt index += 2 } }() + + go func() { + index := 0 + for { + time.Sleep(tts) + _ = esmbp.Save(intData2, core.PeerID(fmt.Sprintf("peer_%d", index)), "") + _ = esmbp.Save(intData2, core.PeerID(fmt.Sprintf("peer_%d", index+1)), "") + index += 2 + } + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) mb, err := esmbp.GetEpochStartMetaBlock(ctx) cancel() diff --git a/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go b/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go index d659989896b..8700b1daa24 100644 --- a/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go +++ b/epochStart/bootstrap/factory/epochStartInterceptorsContainerFactory.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/typeConverters" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -25,23 +26,24 @@ const timeSpanForBadHeaders = time.Minute // ArgsEpochStartInterceptorContainer holds the arguments needed for creating a new epoch start interceptors // container factory type ArgsEpochStartInterceptorContainer struct { - CoreComponents process.CoreComponentsHolder - CryptoComponents process.CryptoComponentsHolder - Config config.Config - ShardCoordinator sharding.Coordinator - MainMessenger process.TopicHandler - FullArchiveMessenger process.TopicHandler - 
DataPool dataRetriever.PoolsHolder - WhiteListHandler update.WhiteListHandler - WhiteListerVerifiedTxs update.WhiteListHandler - AddressPubkeyConv core.PubkeyConverter - NonceConverter typeConverters.Uint64ByteSliceConverter - ChainID []byte - ArgumentsParser process.ArgumentsParser - HeaderIntegrityVerifier process.HeaderIntegrityVerifier - RequestHandler process.RequestHandler - SignaturesHandler process.SignaturesHandler - NodeOperationMode common.NodeOperation + CoreComponents process.CoreComponentsHolder + CryptoComponents process.CryptoComponentsHolder + Config config.Config + ShardCoordinator sharding.Coordinator + MainMessenger process.TopicHandler + FullArchiveMessenger process.TopicHandler + DataPool dataRetriever.PoolsHolder + WhiteListHandler update.WhiteListHandler + WhiteListerVerifiedTxs update.WhiteListHandler + AddressPubkeyConv core.PubkeyConverter + NonceConverter typeConverters.Uint64ByteSliceConverter + ChainID []byte + ArgumentsParser process.ArgumentsParser + HeaderIntegrityVerifier process.HeaderIntegrityVerifier + RequestHandler process.RequestHandler + SignaturesHandler process.SignaturesHandler + NodeOperationMode common.NodeOperation + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // NewEpochStartInterceptorsContainer will return a real interceptors container factory, but with many disabled components @@ -78,36 +80,37 @@ func NewEpochStartInterceptorsContainer(args ArgsEpochStartInterceptorContainer) hardforkTrigger := disabledFactory.HardforkTrigger() containerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: args.CoreComponents, - CryptoComponents: cryptoComponents, - Accounts: accountsAdapter, - ShardCoordinator: args.ShardCoordinator, - NodesCoordinator: nodesCoordinator, - MainMessenger: args.MainMessenger, - FullArchiveMessenger: args.FullArchiveMessenger, - Store: storer, - DataPool: args.DataPool, - MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, - TxFeeHandler: feeHandler, - BlockBlackList: blackListHandler, - HeaderSigVerifier: headerSigVerifier, - HeaderIntegrityVerifier: args.HeaderIntegrityVerifier, - ValidityAttester: validityAttester, - EpochStartTrigger: epochStartTrigger, - WhiteListHandler: args.WhiteListHandler, - WhiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, - AntifloodHandler: antiFloodHandler, - ArgumentsParser: args.ArgumentsParser, - PreferredPeersHolder: disabled.NewPreferredPeersHolder(), - SizeCheckDelta: uint32(sizeCheckDelta), - RequestHandler: args.RequestHandler, - PeerSignatureHandler: cryptoComponents.PeerSignatureHandler(), - SignaturesHandler: args.SignaturesHandler, - HeartbeatExpiryTimespanInSec: args.Config.HeartbeatV2.HeartbeatExpiryTimespanInSec, - MainPeerShardMapper: peerShardMapper, - FullArchivePeerShardMapper: fullArchivePeerShardMapper, - HardforkTrigger: hardforkTrigger, - NodeOperationMode: args.NodeOperationMode, + CoreComponents: args.CoreComponents, + CryptoComponents: cryptoComponents, + Accounts: accountsAdapter, + ShardCoordinator: args.ShardCoordinator, + NodesCoordinator: nodesCoordinator, + MainMessenger: args.MainMessenger, + FullArchiveMessenger: args.FullArchiveMessenger, + Store: storer, + DataPool: args.DataPool, + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: feeHandler, + BlockBlackList: blackListHandler, + HeaderSigVerifier: headerSigVerifier, + HeaderIntegrityVerifier: args.HeaderIntegrityVerifier, + ValidityAttester: validityAttester, + EpochStartTrigger: epochStartTrigger, + 
WhiteListHandler:                args.WhiteListHandler, + WhiteListerVerifiedTxs:          args.WhiteListerVerifiedTxs, + AntifloodHandler:                antiFloodHandler, + ArgumentsParser:                 args.ArgumentsParser, + PreferredPeersHolder:            disabled.NewPreferredPeersHolder(), + SizeCheckDelta:                  uint32(sizeCheckDelta), + RequestHandler:                  args.RequestHandler, + PeerSignatureHandler:            cryptoComponents.PeerSignatureHandler(), + SignaturesHandler:               args.SignaturesHandler, + HeartbeatExpiryTimespanInSec:    args.Config.HeartbeatV2.HeartbeatExpiryTimespanInSec, + MainPeerShardMapper:             peerShardMapper, + FullArchivePeerShardMapper:      fullArchivePeerShardMapper, + HardforkTrigger:                 hardforkTrigger, + NodeOperationMode:               args.NodeOperationMode, + InterceptedDataVerifierFactory:  args.InterceptedDataVerifierFactory, } interceptorsContainerFactory, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(containerFactoryArgs) diff --git a/epochStart/bootstrap/fromLocalStorage.go b/epochStart/bootstrap/fromLocalStorage.go index 868d0359ef5..0572d3b376e 100644 --- a/epochStart/bootstrap/fromLocalStorage.go +++ b/epochStart/bootstrap/fromLocalStorage.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/bootstrap/disabled" @@ -130,6 +131,7 @@ func (e *epochStartBootstrap) prepareEpochFromStorage() (Parameters, error) { if err != nil { return Parameters{}, err } + e.requestHandler.SetEpoch(e.epochStartMeta.GetEpoch()) err = e.createSyncers() if err != nil { diff --git a/epochStart/bootstrap/interface.go b/epochStart/bootstrap/interface.go index bfc293032ee..2dabecc52a6 100644 --- a/epochStart/bootstrap/interface.go +++ b/epochStart/bootstrap/interface.go @@ -49,8 +49,12 @@ type Messenger interface { // RequestHandler defines which methods a request handler should implement type RequestHandler interface { RequestStartOfEpochMetaBlock(epoch uint32) + RequestMetaHeaderByNonce(nonce uint64) SetNumPeersToQuery(topic string, intra int, cross int) error GetNumPeersToQuery(topic string) (int, int, error) + RequestEquivalentProofByNonce(headerShard uint32, headerNonce uint64) + RequestEquivalentProofByHash(headerShard uint32, headerHash []byte) + SetEpoch(epoch uint32) IsInterfaceNil() bool } @@ -60,3 +64,12 @@ type NodeTypeProviderHandler interface { GetType() core.NodeType IsInterfaceNil() bool } + +// ProofsPool defines the behaviour of a proofs pool component +type ProofsPool interface { + RegisterHandler(handler func(headerProof data.HeaderProofHandler)) + GetProof(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) + GetProofByNonce(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) + HasProof(shardID uint32, headerHash []byte) bool + IsInterfaceNil() bool +} diff --git a/epochStart/bootstrap/metaStorageHandler.go b/epochStart/bootstrap/metaStorageHandler.go index 01f65ccabe6..82a63d856a3 100644 --- a/epochStart/bootstrap/metaStorageHandler.go +++ b/epochStart/bootstrap/metaStorageHandler.go @@ -61,6 +61,8 @@ func NewMetaStorageHandler(args StorageHandlerArgs) (*metaStorageHandler, error) currentEpoch: args.CurrentEpoch, uint64Converter: args.Uint64Converter, nodesCoordinatorRegistryFactory: args.NodesCoordinatorRegistryFactory, + proofsPool: args.ProofsPool, + 
enableEpochsHandler: args.EnableEpochsHandler, } return &metaStorageHandler{baseStorageHandler: base}, nil diff --git a/epochStart/bootstrap/metaStorageHandler_test.go b/epochStart/bootstrap/metaStorageHandler_test.go index 92603df176a..ba47388db5f 100644 --- a/epochStart/bootstrap/metaStorageHandler_test.go +++ b/epochStart/bootstrap/metaStorageHandler_test.go @@ -18,6 +18,8 @@ import ( "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" @@ -43,6 +45,8 @@ func createStorageHandlerArgs() StorageHandlerArgs { NodeProcessingMode: common.Normal, StateStatsHandler: disabled.NewStateStatistics(), RepopulateTokensSupplies: false, + ProofsPool: &dataRetrieverMocks.ProofsPoolMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } diff --git a/epochStart/bootstrap/process.go b/epochStart/bootstrap/process.go index 101da829d73..81e010a7eca 100644 --- a/epochStart/bootstrap/process.go +++ b/epochStart/bootstrap/process.go @@ -14,6 +14,10 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice" + logger "github.com/multiversx/mx-chain-logger-go" + + "github.com/multiversx/mx-chain-go/process/interceptors/processor" + "github.com/multiversx/mx-chain-go/common" disabledCommon "github.com/multiversx/mx-chain-go/common/disabled" "github.com/multiversx/mx-chain-go/common/ordering" @@ -52,7 +56,6 @@ import ( "github.com/multiversx/mx-chain-go/trie/storageMarker" "github.com/multiversx/mx-chain-go/update" updateSync "github.com/multiversx/mx-chain-go/update/sync" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("epochStart/bootstrap") @@ -121,6 +124,8 @@ type epochStartBootstrap struct { nodeProcessingMode common.NodeProcessingMode nodeOperationMode common.NodeOperation stateStatsHandler common.StateStatisticsHandler + enableEpochsHandler common.EnableEpochsHandler + // created components requestHandler process.RequestHandler mainInterceptorContainer process.InterceptorsContainer @@ -153,6 +158,8 @@ type epochStartBootstrap struct { nodeType core.NodeType startEpoch uint32 shuffledOut bool + + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } type baseDataInStorage struct { @@ -191,6 +198,8 @@ type ArgsEpochStartBootstrap struct { NodeProcessingMode common.NodeProcessingMode StateStatsHandler common.StateStatisticsHandler NodesCoordinatorRegistryFactory nodesCoordinator.NodesCoordinatorRegistryFactory + EnableEpochsHandler common.EnableEpochsHandler + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } type dataToSync struct { @@ -243,6 +252,8 @@ func NewEpochStartBootstrap(args 
ArgsEpochStartBootstrap) (*epochStartBootstrap, stateStatsHandler: args.StateStatsHandler, startEpoch: args.GeneralConfig.EpochStartConfig.GenesisEpoch, nodesCoordinatorRegistryFactory: args.NodesCoordinatorRegistryFactory, + enableEpochsHandler: args.EnableEpochsHandler, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, } if epochStartProvider.prefsConfig.FullArchive { @@ -548,22 +559,27 @@ func (e *epochStartBootstrap) prepareComponentsToSyncFromNetwork() error { thresholdForConsideringMetaBlockCorrect, epochStartConfig.MinNumConnectedPeersToStart, epochStartConfig.MinNumOfPeersToConsiderBlockValid, + e.enableEpochsHandler, + e.dataPool.Proofs(), ) if err != nil { return err } argsEpochStartSyncer := ArgsNewEpochStartMetaSyncer{ - CoreComponentsHolder: e.coreComponentsHolder, - CryptoComponentsHolder: e.cryptoComponentsHolder, - RequestHandler: e.requestHandler, - Messenger: e.mainMessenger, - ShardCoordinator: e.shardCoordinator, - EconomicsData: e.economicsData, - WhitelistHandler: e.whiteListHandler, - StartInEpochConfig: epochStartConfig, - HeaderIntegrityVerifier: e.headerIntegrityVerifier, - MetaBlockProcessor: metaBlockProcessor, + CoreComponentsHolder: e.coreComponentsHolder, + CryptoComponentsHolder: e.cryptoComponentsHolder, + RequestHandler: e.requestHandler, + Messenger: e.mainMessenger, + ShardCoordinator: e.shardCoordinator, + EconomicsData: e.economicsData, + WhitelistHandler: e.whiteListHandler, + StartInEpochConfig: epochStartConfig, + HeaderIntegrityVerifier: e.headerIntegrityVerifier, + MetaBlockProcessor: metaBlockProcessor, + InterceptedDataVerifierFactory: e.interceptedDataVerifierFactory, + ProofsPool: e.dataPool.Proofs(), + ProofsInterceptorProcessor: processor.NewEquivalentProofsInterceptorProcessor(), } e.epochStartMetaBlockSyncer, err = NewEpochStartMetaSyncer(argsEpochStartSyncer) if err != nil { @@ -576,20 +592,21 @@ func (e *epochStartBootstrap) prepareComponentsToSyncFromNetwork() error { func (e *epochStartBootstrap) createSyncers() error { var err error args := factoryInterceptors.ArgsEpochStartInterceptorContainer{ - CoreComponents: e.coreComponentsHolder, - CryptoComponents: e.cryptoComponentsHolder, - Config: e.generalConfig, - ShardCoordinator: e.shardCoordinator, - MainMessenger: e.mainMessenger, - FullArchiveMessenger: e.fullArchiveMessenger, - DataPool: e.dataPool, - WhiteListHandler: e.whiteListHandler, - WhiteListerVerifiedTxs: e.whiteListerVerifiedTxs, - ArgumentsParser: e.argumentsParser, - HeaderIntegrityVerifier: e.headerIntegrityVerifier, - RequestHandler: e.requestHandler, - SignaturesHandler: e.mainMessenger, - NodeOperationMode: e.nodeOperationMode, + CoreComponents: e.coreComponentsHolder, + CryptoComponents: e.cryptoComponentsHolder, + Config: e.generalConfig, + ShardCoordinator: e.shardCoordinator, + MainMessenger: e.mainMessenger, + FullArchiveMessenger: e.fullArchiveMessenger, + DataPool: e.dataPool, + WhiteListHandler: e.whiteListHandler, + WhiteListerVerifiedTxs: e.whiteListerVerifiedTxs, + ArgumentsParser: e.argumentsParser, + HeaderIntegrityVerifier: e.headerIntegrityVerifier, + RequestHandler: e.requestHandler, + SignaturesHandler: e.mainMessenger, + NodeOperationMode: e.nodeOperationMode, + InterceptedDataVerifierFactory: e.interceptedDataVerifierFactory, } e.mainInterceptorContainer, e.fullArchiveInterceptorContainer, err = factoryInterceptors.NewEpochStartInterceptorsContainer(args) @@ -609,10 +626,12 @@ func (e *epochStartBootstrap) createSyncers() error { } syncMissingHeadersArgs := 
updateSync.ArgsNewMissingHeadersByHashSyncer{ - Storage: disabled.CreateMemUnit(), - Cache: e.dataPool.Headers(), - Marshalizer: e.coreComponentsHolder.InternalMarshalizer(), - RequestHandler: e.requestHandler, + Storage: disabled.CreateMemUnit(), + Cache: e.dataPool.Headers(), + ProofsPool: e.dataPool.Proofs(), + Marshalizer: e.coreComponentsHolder.InternalMarshalizer(), + RequestHandler: e.requestHandler, + EnableEpochsHandler: e.enableEpochsHandler, } e.headersSyncer, err = updateSync.NewMissingheadersByHashSyncer(syncMissingHeadersArgs) if err != nil { @@ -620,9 +639,11 @@ func (e *epochStartBootstrap) createSyncers() error { } epochStartShardHeaderSyncerArgs := updateSync.ArgsPendingEpochStartShardHeaderSyncer{ - HeadersPool: e.dataPool.Headers(), - Marshalizer: e.coreComponentsHolder.InternalMarshalizer(), - RequestHandler: e.requestHandler, + HeadersPool: e.dataPool.Headers(), + ProofsPool: e.dataPool.Proofs(), + Marshalizer: e.coreComponentsHolder.InternalMarshalizer(), + RequestHandler: e.requestHandler, + EnableEpochsHandler: e.enableEpochsHandler, } e.epochStartShardHeaderSyncer, err = updateSync.NewPendingEpochStartShardHeaderSyncer(epochStartShardHeaderSyncerArgs) if err != nil { @@ -647,7 +668,10 @@ func (e *epochStartBootstrap) createSyncers() error { func (e *epochStartBootstrap) syncHeadersFrom(meta data.MetaHeaderHandler) (map[string]data.HeaderHandler, error) { hashesToRequest := make([][]byte, 0, len(meta.GetEpochStartHandler().GetLastFinalizedHeaderHandlers())+1) shardIds := make([]uint32, 0, len(meta.GetEpochStartHandler().GetLastFinalizedHeaderHandlers())+1) - + epochStartMetaHash, err := core.CalculateHash(e.coreComponentsHolder.InternalMarshalizer(), e.coreComponentsHolder.Hasher(), meta) + if err != nil { + return nil, err + } for _, epochStartData := range meta.GetEpochStartHandler().GetLastFinalizedHeaderHandlers() { hashesToRequest = append(hashesToRequest, epochStartData.GetHeaderHash()) shardIds = append(shardIds, epochStartData.GetShardID()) @@ -658,8 +682,13 @@ func (e *epochStartBootstrap) syncHeadersFrom(meta data.MetaHeaderHandler) (map[ shardIds = append(shardIds, core.MetachainShardId) } + // add the epoch start meta hash to the list to sync its proof + // TODO: this can be removed when the proof will be loaded from storage + hashesToRequest = append(hashesToRequest, epochStartMetaHash) + shardIds = append(shardIds, core.MetachainShardId) + ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeToWaitForRequestedData) - err := e.headersSyncer.SyncMissingHeadersByHash(shardIds, hashesToRequest, ctx) + err = e.headersSyncer.SyncMissingHeadersByHash(shardIds, hashesToRequest, ctx) cancel() if err != nil { return nil, err @@ -677,7 +706,7 @@ func (e *epochStartBootstrap) syncHeadersFrom(meta data.MetaHeaderHandler) (map[ return syncedHeaders, nil } -// Bootstrap will handle requesting and receiving the needed information the node will bootstrap from +// requestAndProcessing will handle requesting and receiving the needed information the node will bootstrap from func (e *epochStartBootstrap) requestAndProcessing() (Parameters, error) { var err error e.baseData.numberOfShards = uint32(len(e.epochStartMeta.GetEpochStartHandler().GetLastFinalizedHeaderHandlers())) @@ -775,6 +804,7 @@ func (e *epochStartBootstrap) processNodesConfig(pubKey []byte) ([]*block.MiniBl RequestHandler: e.requestHandler, ChanceComputer: e.rater, GenesisNodesConfig: e.genesisNodesConfig, + ChainParametersHandler: e.coreComponentsHolder.ChainParametersHandler(), 
NodeShuffler: e.nodeShuffler, Hasher: e.coreComponentsHolder.Hasher(), PubKey: pubKey, @@ -815,6 +845,8 @@ func (e *epochStartBootstrap) requestAndProcessForMeta(peerMiniBlocks []*block.M ManagedPeersHolder: e.cryptoComponentsHolder.ManagedPeersHolder(), NodeProcessingMode: e.nodeProcessingMode, StateStatsHandler: e.stateStatsHandler, + ProofsPool: e.dataPool.Proofs(), + EnableEpochsHandler: e.enableEpochsHandler, } storageHandlerComponent, err := NewMetaStorageHandler(argsStorageHandler) if err != nil { @@ -1036,6 +1068,8 @@ func (e *epochStartBootstrap) requestAndProcessForShard(peerMiniBlocks []*block. ManagedPeersHolder: e.cryptoComponentsHolder.ManagedPeersHolder(), NodeProcessingMode: e.nodeProcessingMode, StateStatsHandler: e.stateStatsHandler, + ProofsPool: e.dataPool.Proofs(), + EnableEpochsHandler: e.enableEpochsHandler, } storageHandlerComponent, err := NewShardStorageHandler(argsStorageHandler) if err != nil { @@ -1335,6 +1369,7 @@ func (e *epochStartBootstrap) createRequestHandler() error { FullArchivePreferredPeersHolder: disabled.NewPreferredPeersHolder(), PeersRatingHandler: disabled.NewDisabledPeersRatingHandler(), SizeCheckDelta: 0, + EnableEpochsHandler: e.enableEpochsHandler, } requestersFactory, err := requesterscontainer.NewMetaRequestersContainerFactory(requestersContainerArgs) if err != nil { diff --git a/epochStart/bootstrap/process_test.go b/epochStart/bootstrap/process_test.go index fca682151eb..51a36dd92a4 100644 --- a/epochStart/bootstrap/process_test.go +++ b/epochStart/bootstrap/process_test.go @@ -20,6 +20,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/graceperiod" "github.com/multiversx/mx-chain-go/common/statistics" disabledStatistics "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" @@ -29,12 +30,15 @@ import ( "github.com/multiversx/mx-chain-go/epochStart/bootstrap/types" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/process" + processMock "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" epochStartMocks "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks/epochStart" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ -74,6 +78,17 @@ func createPkBytes(numShards uint32) map[uint32][]byte { } func createComponentsForEpochStart() (*mock.CoreComponentsMock, *mock.CryptoComponentsMock) { + chainParams := &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() 
config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + } + + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) + return &mock.CoreComponentsMock{ IntMarsh: &mock.MarshalizerMock{}, Marsh: &mock.MarshalizerMock{}, @@ -95,6 +110,8 @@ func createComponentsForEpochStart() (*mock.CoreComponentsMock, *mock.CryptoComp return 0 }, }, + EpochChangeGracePeriodHandlerField: gracePeriod, + ChainParametersHandlerField: chainParams, }, &mock.CryptoComponentsMock{ PubKey: &cryptoMocks.PublicKeyStub{}, @@ -143,6 +160,7 @@ func createMockEpochStartBootstrapArgs( PeerAccountsTrieStorage: generalCfg.PeerAccountsTrieStorage, HeartbeatV2: generalCfg.HeartbeatV2, Hardfork: generalCfg.Hardfork, + ProofsStorage: generalCfg.ProofsStorage, EvictionWaitingList: config.EvictionWaitingListConfig{ HashesSize: 100, RootHashesSize: 100, @@ -240,8 +258,10 @@ func createMockEpochStartBootstrapArgs( FlagsConfig: config.ContextFlagsConfig{ ForceStartFromNetwork: false, }, - TrieSyncStatisticsProvider: &testscommon.SizeSyncStatisticsHandlerStub{}, - StateStatsHandler: disabledStatistics.NewStateStatistics(), + TrieSyncStatisticsProvider: &testscommon.SizeSyncStatisticsHandlerStub{}, + StateStatsHandler: disabledStatistics.NewStateStatistics(), + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + InterceptedDataVerifierFactory: &processMock.InterceptedDataVerifierFactoryMock{}, } } @@ -972,22 +992,26 @@ func TestCreateSyncers(t *testing.T) { return testscommon.NewShardedDataStub() }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TrieNodesCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, PeerAuthenticationsCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, HeartbeatsCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() + }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} }, } epochStartProvider.whiteListHandler = &testscommon.WhiteListHandlerStub{} epochStartProvider.whiteListerVerifiedTxs = &testscommon.WhiteListHandlerStub{} epochStartProvider.requestHandler = &testscommon.RequestHandlerStub{} epochStartProvider.storageService = &storageMocks.ChainStorerStub{} + epochStartProvider.interceptedDataVerifierFactory = &processMock.InterceptedDataVerifierFactoryMock{} err := epochStartProvider.createSyncers() assert.Nil(t, err) @@ -1038,7 +1062,7 @@ func TestSyncValidatorAccountsState_NilRequestHandlerErr(t *testing.T) { epochStartProvider, _ := NewEpochStartBootstrap(args) epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1084,7 +1108,7 @@ func TestSyncUserAccountsState(t *testing.T) { epochStartProvider.shardCoordinator = mock.NewMultipleShardsCoordinatorMock() epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1331,6 +1355,9 @@ func 
TestRequestAndProcessForShard_ShouldFail(t *testing.T) { TrieNodesCalled: func() storage.Cacher { return nil }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } epochStartProvider.miniBlocksSyncer = &epochStartMocks.PendingMiniBlockSyncHandlerStub{} @@ -1400,12 +1427,15 @@ func TestRequestAndProcessForShard_ShouldFail(t *testing.T) { } epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, } }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } epochStartProvider.miniBlocksSyncer = &epochStartMocks.PendingMiniBlockSyncHandlerStub{} @@ -1437,6 +1467,11 @@ func TestRequestAndProcessForMeta_ShouldFail(t *testing.T) { epochStartProvider, _ := NewEpochStartBootstrap(args) epochStartProvider.epochStartMeta = metaBlock + epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, + } epochStartProvider.shardCoordinator = nil @@ -1514,12 +1549,15 @@ func TestRequestAndProcessForMeta_ShouldFail(t *testing.T) { } epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, } }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } epochStartProvider.miniBlocksSyncer = &epochStartMocks.PendingMiniBlockSyncHandlerStub{} @@ -1620,11 +1658,14 @@ func TestRequestAndProcessing(t *testing.T) { epochStartProvider, _ := NewEpochStartBootstrap(args) epochStartProvider.epochStartMeta = epochStartMetaBlock + epochStartMetaHash, err := core.CalculateHash(epochStartProvider.coreComponentsHolder.InternalMarshalizer(), epochStartProvider.coreComponentsHolder.Hasher(), epochStartMetaBlock) + require.Nil(t, err) + expectedErr := errors.New("sync miniBlocksSyncer headers by hash error") epochStartProvider.headersSyncer = &epochStartMocks.HeadersByHashSyncerStub{ SyncMissingHeadersByHashCalled: func(shardIDs []uint32, headersHashes [][]byte, ctx context.Context) error { - assert.Equal(t, [][]byte{notarizedShardHeaderHash}, headersHashes) - assert.Equal(t, []uint32{shardId}, shardIDs) + assert.Equal(t, [][]byte{notarizedShardHeaderHash, epochStartMetaHash}, headersHashes) + assert.Equal(t, []uint32{shardId, core.MetachainShardId}, shardIDs) return expectedErr }, } @@ -1884,10 +1925,10 @@ func TestRequestAndProcessing(t *testing.T) { } epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1899,6 +1940,9 @@ func TestRequestAndProcessing(t *testing.T) { CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &validatorInfoCacherStub.ValidatorInfoCacherStub{} }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } epochStartProvider.requestHandler = &testscommon.RequestHandlerStub{} epochStartProvider.miniBlocksSyncer = 
&epochStartMocks.PendingMiniBlockSyncHandlerStub{} @@ -1954,10 +1998,10 @@ func TestRequestAndProcessing(t *testing.T) { } epochStartProvider.dataPool = &dataRetrieverMock.PoolsHolderStub{ MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1969,6 +2013,9 @@ func TestRequestAndProcessing(t *testing.T) { CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &validatorInfoCacherStub.ValidatorInfoCacherStub{} }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } epochStartProvider.requestHandler = &testscommon.RequestHandlerStub{} epochStartProvider.miniBlocksSyncer = &epochStartMocks.PendingMiniBlockSyncHandlerStub{} @@ -2124,10 +2171,10 @@ func TestEpochStartBootstrap_WithDisabledShardIDAsObserver(t *testing.T) { return testscommon.NewShardedDataStub() }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TrieNodesCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &validatorInfoCacherStub.ValidatorInfoCacherStub{} @@ -2460,16 +2507,19 @@ func TestSyncSetGuardianTransaction(t *testing.T) { return testscommon.NewShardedDataStub() }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TrieNodesCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, PeerAuthenticationsCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, HeartbeatsCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() + }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} }, } epochStartProvider.whiteListHandler = &testscommon.WhiteListHandlerStub{ @@ -2517,8 +2567,9 @@ func TestSyncSetGuardianTransaction(t *testing.T) { TimestampField: 0, } - err = interceptor.ProcessReceivedMessage(msg, "pid", nil) + msgID, err := interceptor.ProcessReceivedMessage(msg, "pid", nil) assert.Nil(t, err) + assert.NotNil(t, msgID) time.Sleep(time.Second) diff --git a/epochStart/bootstrap/shardStorageHandler.go b/epochStart/bootstrap/shardStorageHandler.go index b10abef815c..469089ee973 100644 --- a/epochStart/bootstrap/shardStorageHandler.go +++ b/epochStart/bootstrap/shardStorageHandler.go @@ -65,6 +65,8 @@ func NewShardStorageHandler(args StorageHandlerArgs) (*shardStorageHandler, erro currentEpoch: args.CurrentEpoch, uint64Converter: args.Uint64Converter, nodesCoordinatorRegistryFactory: args.NodesCoordinatorRegistryFactory, + proofsPool: args.ProofsPool, + enableEpochsHandler: args.EnableEpochsHandler, } return &shardStorageHandler{baseStorageHandler: base}, nil diff --git a/epochStart/bootstrap/storageProcess.go b/epochStart/bootstrap/storageProcess.go index 809b0dfbb8b..1004583e07b 100644 --- a/epochStart/bootstrap/storageProcess.go +++ b/epochStart/bootstrap/storageProcess.go @@ -11,6 +11,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" 
"github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/multiversx/mx-chain-go/process/interceptors/processor" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -177,16 +179,19 @@ func (sesb *storageEpochStartBootstrap) prepareComponentsToSync() error { } argsEpochStartSyncer := ArgsNewEpochStartMetaSyncer{ - CoreComponentsHolder: sesb.coreComponentsHolder, - CryptoComponentsHolder: sesb.cryptoComponentsHolder, - RequestHandler: sesb.requestHandler, - Messenger: sesb.mainMessenger, - ShardCoordinator: sesb.shardCoordinator, - EconomicsData: sesb.economicsData, - WhitelistHandler: sesb.whiteListHandler, - StartInEpochConfig: sesb.generalConfig.EpochStartConfig, - HeaderIntegrityVerifier: sesb.headerIntegrityVerifier, - MetaBlockProcessor: metablockProcessor, + CoreComponentsHolder: sesb.coreComponentsHolder, + CryptoComponentsHolder: sesb.cryptoComponentsHolder, + RequestHandler: sesb.requestHandler, + Messenger: sesb.mainMessenger, + ShardCoordinator: sesb.shardCoordinator, + EconomicsData: sesb.economicsData, + WhitelistHandler: sesb.whiteListHandler, + StartInEpochConfig: sesb.generalConfig.EpochStartConfig, + HeaderIntegrityVerifier: sesb.headerIntegrityVerifier, + MetaBlockProcessor: metablockProcessor, + InterceptedDataVerifierFactory: sesb.interceptedDataVerifierFactory, + ProofsPool: sesb.dataPool.Proofs(), + ProofsInterceptorProcessor: processor.NewEquivalentProofsInterceptorProcessor(), } sesb.epochStartMetaBlockSyncer, err = NewEpochStartMetaSyncer(argsEpochStartSyncer) @@ -409,6 +414,7 @@ func (sesb *storageEpochStartBootstrap) processNodesConfig(pubKey []byte) error RequestHandler: sesb.requestHandler, ChanceComputer: sesb.rater, GenesisNodesConfig: sesb.genesisNodesConfig, + ChainParametersHandler: sesb.coreComponentsHolder.ChainParametersHandler(), NodeShuffler: sesb.nodeShuffler, Hasher: sesb.coreComponentsHolder.Hasher(), PubKey: pubKey, diff --git a/epochStart/bootstrap/storageProcess_test.go b/epochStart/bootstrap/storageProcess_test.go index 161681f744e..34a7f97cbcc 100644 --- a/epochStart/bootstrap/storageProcess_test.go +++ b/epochStart/bootstrap/storageProcess_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/process" + processMock "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" @@ -128,6 +129,7 @@ func TestStorageEpochStartBootstrap_BootstrapMetablockNotFound(t *testing.T) { } args.GeneralConfig = testscommon.GetGeneralConfig() args.GeneralConfig.EpochStartConfig.RoundsPerEpoch = roundsPerEpoch + args.InterceptedDataVerifierFactory = &processMock.InterceptedDataVerifierFactoryMock{} sesb, _ := NewStorageEpochStartBootstrap(args) params, err := sesb.Bootstrap() diff --git a/epochStart/bootstrap/syncEpochStartMeta.go b/epochStart/bootstrap/syncEpochStartMeta.go index fa764a04c4a..07b3e9f1bd1 100644 --- a/epochStart/bootstrap/syncEpochStartMeta.go +++ b/epochStart/bootstrap/syncEpochStartMeta.go @@ -4,12 +4,15 @@ import ( "context" "time" + 
"github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/bootstrap/disabled" "github.com/multiversx/mx-chain-go/process" @@ -22,27 +25,32 @@ import ( var _ epochStart.StartOfEpochMetaSyncer = (*epochStartMetaSyncer)(nil) type epochStartMetaSyncer struct { - requestHandler RequestHandler - messenger Messenger - marshalizer marshal.Marshalizer - hasher hashing.Hasher - singleDataInterceptor process.Interceptor - metaBlockProcessor EpochStartMetaBlockInterceptorProcessor + requestHandler RequestHandler + messenger Messenger + marshalizer marshal.Marshalizer + hasher hashing.Hasher + singleDataInterceptor process.Interceptor + proofsInterceptor process.Interceptor + metaBlockProcessor EpochStartMetaBlockInterceptorProcessor + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // ArgsNewEpochStartMetaSyncer - type ArgsNewEpochStartMetaSyncer struct { - CoreComponentsHolder process.CoreComponentsHolder - CryptoComponentsHolder process.CryptoComponentsHolder - RequestHandler RequestHandler - Messenger Messenger - ShardCoordinator sharding.Coordinator - EconomicsData process.EconomicsDataHandler - WhitelistHandler process.WhiteListHandler - StartInEpochConfig config.EpochStartConfig - ArgsParser process.ArgumentsParser - HeaderIntegrityVerifier process.HeaderIntegrityVerifier - MetaBlockProcessor EpochStartMetaBlockInterceptorProcessor + CoreComponentsHolder process.CoreComponentsHolder + CryptoComponentsHolder process.CryptoComponentsHolder + RequestHandler RequestHandler + Messenger Messenger + ShardCoordinator sharding.Coordinator + EconomicsData process.EconomicsDataHandler + WhitelistHandler process.WhiteListHandler + StartInEpochConfig config.EpochStartConfig + ArgsParser process.ArgumentsParser + HeaderIntegrityVerifier process.HeaderIntegrityVerifier + MetaBlockProcessor EpochStartMetaBlockInterceptorProcessor + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory + ProofsPool dataRetriever.ProofsPool + ProofsInterceptorProcessor process.InterceptorProcessor } // NewEpochStartMetaSyncer will return a new instance of epochStartMetaSyncer @@ -62,13 +70,20 @@ func NewEpochStartMetaSyncer(args ArgsNewEpochStartMetaSyncer) (*epochStartMetaS if check.IfNil(args.MetaBlockProcessor) { return nil, epochStart.ErrNilMetablockProcessor } + if check.IfNil(args.InterceptedDataVerifierFactory) { + return nil, epochStart.ErrNilInterceptedDataVerifierFactory + } + if check.IfNil(args.ProofsInterceptorProcessor) { + return nil, epochStart.ErrNilEquivalentProofsProcessor + } e := &epochStartMetaSyncer{ - requestHandler: args.RequestHandler, - messenger: args.Messenger, - marshalizer: args.CoreComponentsHolder.InternalMarshalizer(), - hasher: args.CoreComponentsHolder.Hasher(), - metaBlockProcessor: args.MetaBlockProcessor, + requestHandler: args.RequestHandler, + messenger: args.Messenger, + marshalizer: args.CoreComponentsHolder.InternalMarshalizer(), + hasher: 
args.CoreComponentsHolder.Hasher(), + metaBlockProcessor: args.MetaBlockProcessor, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, } argsInterceptedDataFactory := interceptorsFactory.ArgInterceptedDataFactory{ @@ -83,22 +98,58 @@ func NewEpochStartMetaSyncer(args ArgsNewEpochStartMetaSyncer) (*epochStartMetaS EpochStartTrigger: disabled.NewEpochStartTrigger(), ArgsParser: args.ArgsParser, } + argsInterceptedMetaHeaderFactory := interceptorsFactory.ArgInterceptedMetaHeaderFactory{ + ArgInterceptedDataFactory: argsInterceptedDataFactory, + } + + interceptedMetaHdrDataFactory, err := interceptorsFactory.NewInterceptedMetaHeaderDataFactory(&argsInterceptedMetaHeaderFactory) + if err != nil { + return nil, err + } - interceptedMetaHdrDataFactory, err := interceptorsFactory.NewInterceptedMetaHeaderDataFactory(&argsInterceptedDataFactory) + interceptedDataVerifier, err := e.interceptedDataVerifierFactory.Create(factory.MetachainBlocksTopic) if err != nil { return nil, err } e.singleDataInterceptor, err = interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: factory.MetachainBlocksTopic, - DataFactory: interceptedMetaHdrDataFactory, - Processor: args.MetaBlockProcessor, - Throttler: disabled.NewThrottler(), - AntifloodHandler: disabled.NewAntiFloodHandler(), - WhiteListRequest: args.WhitelistHandler, - CurrentPeerId: args.Messenger.ID(), - PreferredPeersHolder: disabled.NewPreferredPeersHolder(), + Topic: factory.MetachainBlocksTopic, + DataFactory: interceptedMetaHdrDataFactory, + Processor: args.MetaBlockProcessor, + Throttler: disabled.NewThrottler(), + AntifloodHandler: disabled.NewAntiFloodHandler(), + WhiteListRequest: args.WhitelistHandler, + CurrentPeerId: args.Messenger.ID(), + PreferredPeersHolder: disabled.NewPreferredPeersHolder(), + InterceptedDataVerifier: interceptedDataVerifier, + }, + ) + if err != nil { + return nil, err + } + + argsInterceptedEquivalentProofsFactory := interceptorsFactory.ArgInterceptedEquivalentProofsFactory{ + ArgInterceptedDataFactory: argsInterceptedDataFactory, + ProofsPool: args.ProofsPool, + } + interceptedEquivalentProofsFactory := interceptorsFactory.NewInterceptedEquivalentProofsFactory(argsInterceptedEquivalentProofsFactory) + if err != nil { + return nil, err + } + + proofsTopic := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + e.proofsInterceptor, err = interceptors.NewSingleDataInterceptor( + interceptors.ArgSingleDataInterceptor{ + Topic: proofsTopic, + DataFactory: interceptedEquivalentProofsFactory, + Processor: args.ProofsInterceptorProcessor, + Throttler: disabled.NewThrottler(), + AntifloodHandler: disabled.NewAntiFloodHandler(), + WhiteListRequest: args.WhitelistHandler, + CurrentPeerId: args.Messenger.ID(), + PreferredPeersHolder: disabled.NewPreferredPeersHolder(), + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -134,6 +185,12 @@ func (e *epochStartMetaSyncer) resetTopicsAndInterceptors() { if err != nil { log.Trace("error unregistering message processors", "error", err) } + + proofsTopic := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + err = e.messenger.UnregisterMessageProcessor(proofsTopic, common.EpochStartInterceptorsIdentifier) + if err != nil { + log.Trace("error unregistering message processors", "error", err) + } } func (e *epochStartMetaSyncer) initTopicForEpochStartMetaBlockInterceptor() error { @@ -143,13 
+200,20 @@ func (e *epochStartMetaSyncer) initTopicForEpochStartMetaBlockInterceptor() erro return err } + proofsTopic := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + err = e.messenger.CreateTopic(proofsTopic, true) + if err != nil { + log.Warn("error messenger create topic", "topic", proofsTopic, "error", err) + return err + } + e.resetTopicsAndInterceptors() err = e.messenger.RegisterMessageProcessor(factory.MetachainBlocksTopic, common.EpochStartInterceptorsIdentifier, e.singleDataInterceptor) if err != nil { return err } - return nil + return e.messenger.RegisterMessageProcessor(proofsTopic, common.EpochStartInterceptorsIdentifier, e.proofsInterceptor) } // IsInterfaceNil returns true if underlying object is nil diff --git a/epochStart/bootstrap/syncEpochStartMeta_test.go b/epochStart/bootstrap/syncEpochStartMeta_test.go index c85efc7304c..8edc24825fd 100644 --- a/epochStart/bootstrap/syncEpochStartMeta_test.go +++ b/epochStart/bootstrap/syncEpochStartMeta_test.go @@ -9,17 +9,21 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common/graceperiod" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/p2p" + processMock "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestNewEpochStartMetaSyncer_NilsShouldError(t *testing.T) { @@ -48,6 +52,12 @@ func TestNewEpochStartMetaSyncer_NilsShouldError(t *testing.T) { ess, err = NewEpochStartMetaSyncer(args) assert.True(t, check.IfNil(ess)) assert.Equal(t, epochStart.ErrNilMetablockProcessor, err) + + args = getEpochStartSyncerArgs() + args.InterceptedDataVerifierFactory = nil + ess, err = NewEpochStartMetaSyncer(args) + assert.True(t, check.IfNil(ess)) + assert.Equal(t, epochStart.ErrNilInterceptedDataVerifierFactory, err) } func TestNewEpochStartMetaSyncer_ShouldWork(t *testing.T) { @@ -71,7 +81,8 @@ func TestEpochStartMetaSyncer_SyncEpochStartMetaRegisterMessengerProcessorFailsS }, } args.Messenger = messenger - ess, _ := NewEpochStartMetaSyncer(args) + ess, err := NewEpochStartMetaSyncer(args) + require.NoError(t, err) mb, err := ess.SyncEpochStartMeta(time.Second) require.Equal(t, expectedErr, err) @@ -131,6 +142,7 @@ func TestEpochStartMetaSyncer_SyncEpochStartMetaShouldWork(t *testing.T) { } func getEpochStartSyncerArgs() ArgsNewEpochStartMetaSyncer { + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, 
GracePeriodInRounds: 1}}) return ArgsNewEpochStartMetaSyncer{ CoreComponentsHolder: &mock.CoreComponentsMock{ IntMarsh: &mock.MarshalizerMock{}, @@ -142,6 +154,7 @@ func getEpochStartSyncerArgs() ArgsNewEpochStartMetaSyncer { ChainIdCalled: func() string { return "chain-ID" }, + EpochChangeGracePeriodHandlerField: gracePeriod, }, CryptoComponentsHolder: &mock.CryptoComponentsMock{ PubKey: &cryptoMocks.PublicKeyStub{}, @@ -159,7 +172,10 @@ func getEpochStartSyncerArgs() ArgsNewEpochStartMetaSyncer { MinNumConnectedPeersToStart: 2, MinNumOfPeersToConsiderBlockValid: 2, }, - HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, - MetaBlockProcessor: &mock.EpochStartMetaBlockProcessorStub{}, + HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + MetaBlockProcessor: &mock.EpochStartMetaBlockProcessorStub{}, + InterceptedDataVerifierFactory: &processMock.InterceptedDataVerifierFactoryMock{}, + ProofsPool: &dataRetriever.ProofsPoolMock{}, + ProofsInterceptorProcessor: &processMock.InterceptorProcessorStub{}, } } diff --git a/epochStart/bootstrap/syncValidatorStatus.go b/epochStart/bootstrap/syncValidatorStatus.go index 0bcb9308311..4a0883f51af 100644 --- a/epochStart/bootstrap/syncValidatorStatus.go +++ b/epochStart/bootstrap/syncValidatorStatus.go @@ -44,6 +44,7 @@ type ArgsNewSyncValidatorStatus struct { RequestHandler process.RequestHandler ChanceComputer nodesCoordinator.ChanceComputer GenesisNodesConfig sharding.GenesisNodesSetupHandler + ChainParametersHandler process.ChainParametersHandler NodeShuffler nodesCoordinator.NodesShuffler PubKey []byte ShardIdAsObserver uint32 @@ -112,8 +113,7 @@ func NewSyncValidatorStatus(args ArgsNewSyncValidatorStatus) (*syncValidatorStat s.memDB = disabled.CreateMemUnit() argsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: int(args.GenesisNodesConfig.GetShardConsensusGroupSize()), - MetaConsensusGroupSize: int(args.GenesisNodesConfig.GetMetaConsensusGroupSize()), + ChainParametersHandler: args.ChainParametersHandler, Marshalizer: args.Marshalizer, Hasher: args.Hasher, Shuffler: args.NodeShuffler, diff --git a/epochStart/bootstrap/syncValidatorStatus_test.go b/epochStart/bootstrap/syncValidatorStatus_test.go index 7cfe6061c77..ee8b7c02dae 100644 --- a/epochStart/bootstrap/syncValidatorStatus_test.go +++ b/epochStart/bootstrap/syncValidatorStatus_test.go @@ -9,12 +9,17 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" epochStartMocks "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks/epochStart" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" 
"github.com/multiversx/mx-chain-go/testscommon/genesisMocks" @@ -22,8 +27,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const initRating = uint32(50) @@ -255,16 +258,17 @@ func getSyncValidatorStatusArgs() ArgsNewSyncValidatorStatus { return ArgsNewSyncValidatorStatus{ DataPool: &dataRetrieverMock.PoolsHolderStub{ MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &vic.ValidatorInfoCacherStub{} }, }, - Marshalizer: &mock.MarshalizerMock{}, - Hasher: &hashingMocks.HasherMock{}, - RequestHandler: &testscommon.RequestHandlerStub{}, - ChanceComputer: &shardingMocks.NodesCoordinatorStub{}, + Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + ChanceComputer: &shardingMocks.NodesCoordinatorStub{}, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{}, GenesisNodesConfig: &genesisMocks.NodesSetupStub{ NumberOfShardsCalled: func() uint32 { return 1 diff --git a/epochStart/errors.go b/epochStart/errors.go index 650c94bfe35..6c567b838f4 100644 --- a/epochStart/errors.go +++ b/epochStart/errors.go @@ -71,6 +71,9 @@ var ErrNilRequestHandler = errors.New("nil request handler") // ErrNilMetaBlocksPool signals that nil metablock pools holder has been provided var ErrNilMetaBlocksPool = errors.New("nil metablocks pool") +// ErrNilProofsPool signals that nil proofs pool has been provided +var ErrNilProofsPool = errors.New("nil proofs pool") + // ErrNilValidatorInfoProcessor signals that a nil validator info processor has been provided var ErrNilValidatorInfoProcessor = errors.New("nil validator info processor") @@ -233,6 +236,9 @@ var ErrNilEpochNotifier = errors.New("nil EpochNotifier") // ErrNilMetablockProcessor signals that a nil metablock processor was provided var ErrNilMetablockProcessor = errors.New("nil metablock processor") +// ErrNilInterceptedDataVerifierFactory signals that a nil intercepted data verifier factory was provided +var ErrNilInterceptedDataVerifierFactory = errors.New("nil intercepted data verifier factory") + // ErrCouldNotInitDelegationSystemSC signals that delegation system sc init failed var ErrCouldNotInitDelegationSystemSC = errors.New("could not init delegation system sc") @@ -346,3 +352,9 @@ var ErrUint32SubtractionOverflow = errors.New("uint32 subtraction overflowed") // ErrReceivedAuctionValidatorsBeforeStakingV4 signals that an auction node has been provided before enabling staking v4 var ErrReceivedAuctionValidatorsBeforeStakingV4 = errors.New("auction node has been provided before enabling staking v4") + +// ErrNilEquivalentProofsProcessor signals that a nil equivalent proofs processor was provided +var ErrNilEquivalentProofsProcessor = errors.New("nil equivalent proofs processor") + +// ErrNilHeadersDataPool signals that a nil headers pool has been provided +var ErrNilHeadersDataPool = errors.New("nil headers data pool") diff --git a/epochStart/interface.go b/epochStart/interface.go index 8c171dc6c3d..5daf24f50a7 100644 --- a/epochStart/interface.go +++ b/epochStart/interface.go @@ -7,9 
+7,10 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/state" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) // TriggerHandler defines the functionalities for an start of epoch trigger @@ -126,7 +127,7 @@ type StartOfEpochMetaSyncer interface { // NodesConfigProvider will provide the necessary information for start in epoch economics block creation type NodesConfigProvider interface { - ConsensusGroupSize(shardID uint32) int + ConsensusGroupSizeForShardAndEpoch(shardID uint32, epoch uint32) int IsInterfaceNil() bool } diff --git a/epochStart/metachain/baseRewards.go b/epochStart/metachain/baseRewards.go index 691dbb1aa88..cde168af06f 100644 --- a/epochStart/metachain/baseRewards.go +++ b/epochStart/metachain/baseRewards.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/rewardTx" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool" @@ -56,7 +57,7 @@ type baseRewardsCreator struct { mapBaseRewardsPerBlockPerValidator map[uint32]*big.Int accumulatedRewards *big.Int protocolSustainabilityValue *big.Int - flagDelegationSystemSCEnabled atomic.Flag //nolint + flagDelegationSystemSCEnabled atomic.Flag // nolint userAccountsDB state.AccountsAdapter enableEpochsHandler common.EnableEpochsHandler mutRewardsData sync.RWMutex @@ -441,19 +442,29 @@ func (brc *baseRewardsCreator) addExecutionOrdering(txHashes [][]byte) { } } -func (brc *baseRewardsCreator) fillBaseRewardsPerBlockPerNode(baseRewardsPerNode *big.Int) { +func (brc *baseRewardsCreator) fillBaseRewardsPerBlockPerNode(baseRewardsPerNode *big.Int, epoch uint32) { brc.mapBaseRewardsPerBlockPerValidator = make(map[uint32]*big.Int) for i := uint32(0); i < brc.shardCoordinator.NumberOfShards(); i++ { - consensusSize := big.NewInt(int64(brc.nodesConfigProvider.ConsensusGroupSize(i))) + consensusSize := big.NewInt(int64(brc.getConsensusGroupSizeForShardAndEpoch(i, epoch))) brc.mapBaseRewardsPerBlockPerValidator[i] = big.NewInt(0).Div(baseRewardsPerNode, consensusSize) log.Debug("baseRewardsPerBlockPerValidator", "shardID", i, "value", brc.mapBaseRewardsPerBlockPerValidator[i].String()) } - consensusSize := big.NewInt(int64(brc.nodesConfigProvider.ConsensusGroupSize(core.MetachainShardId))) + consensusSize := big.NewInt(int64(brc.getConsensusGroupSizeForShardAndEpoch(core.MetachainShardId, epoch))) brc.mapBaseRewardsPerBlockPerValidator[core.MetachainShardId] = big.NewInt(0).Div(baseRewardsPerNode, consensusSize) log.Debug("baseRewardsPerBlockPerValidator", "shardID", core.MetachainShardId, "value", brc.mapBaseRewardsPerBlockPerValidator[core.MetachainShardId].String()) } +func (brc *baseRewardsCreator) getConsensusGroupSizeForShardAndEpoch(shardID uint32, epoch uint32) int { + if epoch == 0 { + return brc.nodesConfigProvider.ConsensusGroupSizeForShardAndEpoch(shardID, 0) + } + + // use previous epoch for fetching the consensus group size, since the epoch start metablock already contains the new epoch + epochForConsensusSize 
:= epoch - 1 + return brc.nodesConfigProvider.ConsensusGroupSizeForShardAndEpoch(shardID, epochForConsensusSize) +} + func (brc *baseRewardsCreator) verifyCreatedRewardMiniBlocksWithMetaBlock(metaBlock data.HeaderHandler, createdMiniBlocks block.MiniBlockSlice) error { numReceivedRewardsMBs := 0 for _, miniBlockHdr := range metaBlock.GetMiniBlockHeaderHandlers() { diff --git a/epochStart/metachain/baseRewards_test.go b/epochStart/metachain/baseRewards_test.go index 2d36f2a2e3a..da21d5d9f19 100644 --- a/epochStart/metachain/baseRewards_test.go +++ b/epochStart/metachain/baseRewards_test.go @@ -1001,15 +1001,23 @@ func TestBaseRewardsCreator_finalizeMiniBlocksEmptyMbsAreRemoved(t *testing.T) { func TestBaseRewardsCreator_fillBaseRewardsPerBlockPerNode(t *testing.T) { t.Parallel() + // should work for epoch 0 even if this is a bad input + testFillBaseRewardsPerBlockPerNode(t, 0) + + // should work for an epoch higher than 0 + testFillBaseRewardsPerBlockPerNode(t, 1) +} + +func testFillBaseRewardsPerBlockPerNode(t *testing.T, epoch uint32) { args := getBaseRewardsArguments() rwd, err := NewBaseRewardsCreator(args) require.Nil(t, err) require.NotNil(t, rwd) baseRewardsPerNode := big.NewInt(1000000) - rwd.fillBaseRewardsPerBlockPerNode(baseRewardsPerNode) - consensusShard := args.NodesConfigProvider.ConsensusGroupSize(0) - consensusMeta := args.NodesConfigProvider.ConsensusGroupSize(core.MetachainShardId) + rwd.fillBaseRewardsPerBlockPerNode(baseRewardsPerNode, epoch) + consensusShard := args.NodesConfigProvider.ConsensusGroupSizeForShardAndEpoch(0, epoch) + consensusMeta := args.NodesConfigProvider.ConsensusGroupSizeForShardAndEpoch(core.MetachainShardId, epoch) expectedRewardPerNodeInShard := big.NewInt(0).Div(baseRewardsPerNode, big.NewInt(int64(consensusShard))) expectedRewardPerNodeInMeta := big.NewInt(0).Div(baseRewardsPerNode, big.NewInt(int64(consensusMeta))) @@ -1188,7 +1196,7 @@ func getBaseRewardsArguments() BaseRewardsCreatorArgs { Marshalizer: &mock.MarshalizerMock{}, DataPool: dataRetrieverMock.NewPoolsHolderMock(), NodesConfigProvider: &shardingMocks.NodesCoordinatorStub{ - ConsensusGroupSizeCalled: func(shardID uint32) int { + ConsensusGroupSizeCalled: func(shardID uint32, _ uint32) int { if shardID == core.MetachainShardId { return 400 } diff --git a/epochStart/metachain/rewards.go b/epochStart/metachain/rewards.go index 0b279d56c32..368a5bec809 100644 --- a/epochStart/metachain/rewards.go +++ b/epochStart/metachain/rewards.go @@ -77,7 +77,7 @@ func (rc *rewardsCreator) CreateRewardsMiniBlocks( return nil, err } - rc.fillBaseRewardsPerBlockPerNode(economicsData.GetRewardsPerBlock()) + rc.fillBaseRewardsPerBlockPerNode(economicsData.GetRewardsPerBlock(), metaBlock.GetEpoch()) err = rc.addValidatorRewardsToMiniBlocks(validatorsInfo, metaBlock, miniBlocks, protSustRwdTx) if err != nil { return nil, err diff --git a/epochStart/metachain/rewardsV2.go b/epochStart/metachain/rewardsV2.go index ac7b5074cc1..b72b3f751d7 100644 --- a/epochStart/metachain/rewardsV2.go +++ b/epochStart/metachain/rewardsV2.go @@ -290,7 +290,7 @@ func (rc *rewardsCreatorV2) computeRewardsPerNode( "baseRewards", baseRewards.String(), "topUpRewards", topUpRewards.String()) - rc.fillBaseRewardsPerBlockPerNode(baseRewardsPerBlock) + rc.fillBaseRewardsPerBlockPerNode(baseRewardsPerBlock, epoch) accumulatedDust := big.NewInt(0) dust := rc.computeBaseRewardsPerNode(nodesRewardInfo, baseRewards) diff --git a/epochStart/metachain/rewardsV2_test.go b/epochStart/metachain/rewardsV2_test.go index c024613ac25..000cf7085e0 
100644 --- a/epochStart/metachain/rewardsV2_test.go +++ b/epochStart/metachain/rewardsV2_test.go @@ -708,7 +708,7 @@ func TestNewRewardsCreatorV2_computeBaseRewardsPerNode(t *testing.T) { for shardID := range shardMap { rwd.mapBaseRewardsPerBlockPerValidator[shardID] = big.NewInt(0).Set(baseRewardPerBlock) - cnsSize := big.NewInt(0).SetInt64(int64(args.NodesConfigProvider.ConsensusGroupSize(shardID))) + cnsSize := big.NewInt(0).SetInt64(int64(args.NodesConfigProvider.ConsensusGroupSizeForShardAndEpoch(shardID, 0))) rwd.mapBaseRewardsPerBlockPerValidator[shardID].Div(rwd.mapBaseRewardsPerBlockPerValidator[shardID], cnsSize) } @@ -1873,8 +1873,8 @@ func createDefaultValidatorInfo( proposerFeesPerNode uint32, nbBlocksPerShard uint32, ) state.ShardValidatorsInfoMapHandler { - cGrShard := uint32(nodesConfigProvider.ConsensusGroupSize(0)) - cGrMeta := uint32(nodesConfigProvider.ConsensusGroupSize(core.MetachainShardId)) + cGrShard := uint32(nodesConfigProvider.ConsensusGroupSizeForShardAndEpoch(0, 0)) + cGrMeta := uint32(nodesConfigProvider.ConsensusGroupSizeForShardAndEpoch(core.MetachainShardId, 0)) nbBlocksSelectedNodeInShard := nbBlocksPerShard * cGrShard / eligibleNodesPerShard nbBlocksSelectedNodeInMeta := nbBlocksPerShard * cGrMeta / eligibleNodesPerShard diff --git a/epochStart/metachain/rewards_test.go b/epochStart/metachain/rewards_test.go index 431e310ba9c..b53246271b2 100644 --- a/epochStart/metachain/rewards_test.go +++ b/epochStart/metachain/rewards_test.go @@ -585,7 +585,7 @@ func TestRewardsCreator_addValidatorRewardsToMiniBlocks(t *testing.T) { LeaderSuccess: 1, }) - rwdc.fillBaseRewardsPerBlockPerNode(mb.EpochStart.Economics.RewardsPerBlock) + rwdc.fillBaseRewardsPerBlockPerNode(mb.EpochStart.Economics.RewardsPerBlock, 0) err := rwdc.addValidatorRewardsToMiniBlocks(valInfo, mb, miniBlocks, &rewardTx.RewardTx{}) assert.Nil(t, err) assert.Equal(t, cloneMb, miniBlocks[0]) @@ -596,7 +596,7 @@ func TestRewardsCreator_ProtocolRewardsForValidatorFromMultipleShards(t *testing args := getRewardsArguments() args.NodesConfigProvider = &shardingMocks.NodesCoordinatorStub{ - ConsensusGroupSizeCalled: func(shardID uint32) int { + ConsensusGroupSizeCalled: func(shardID uint32, _ uint32) int { if shardID == core.MetachainShardId { return 400 } @@ -626,15 +626,15 @@ func TestRewardsCreator_ProtocolRewardsForValidatorFromMultipleShards(t *testing LeaderSuccess: 1, }) - rwdc.fillBaseRewardsPerBlockPerNode(mb.EpochStart.Economics.RewardsPerBlock) + rwdc.fillBaseRewardsPerBlockPerNode(mb.EpochStart.Economics.RewardsPerBlock, 0) rwdInfoData := rwdc.computeValidatorInfoPerRewardAddress(valInfo, &rewardTx.RewardTx{}, 0) assert.Equal(t, 1, len(rwdInfoData)) rwdInfo := rwdInfoData[pubkey] assert.Equal(t, rwdInfo.address, pubkey) assert.Equal(t, rwdInfo.accumulatedFees.Cmp(big.NewInt(200)), 0) - protocolRewards := uint64(valInfo.GetShardValidatorsInfoMap()[0][0].GetNumSelectedInSuccessBlocks()) * (mb.EpochStart.Economics.RewardsPerBlock.Uint64() / uint64(args.NodesConfigProvider.ConsensusGroupSize(0))) - protocolRewards += uint64(valInfo.GetShardValidatorsInfoMap()[core.MetachainShardId][0].GetNumSelectedInSuccessBlocks()) * (mb.EpochStart.Economics.RewardsPerBlock.Uint64() / uint64(args.NodesConfigProvider.ConsensusGroupSize(core.MetachainShardId))) + protocolRewards := uint64(valInfo.GetShardValidatorsInfoMap()[0][0].GetNumSelectedInSuccessBlocks()) * (mb.EpochStart.Economics.RewardsPerBlock.Uint64() / uint64(args.NodesConfigProvider.ConsensusGroupSizeForShardAndEpoch(0, 0))) + protocolRewards += 
uint64(valInfo.GetShardValidatorsInfoMap()[core.MetachainShardId][0].GetNumSelectedInSuccessBlocks()) * (mb.EpochStart.Economics.RewardsPerBlock.Uint64() / uint64(args.NodesConfigProvider.ConsensusGroupSizeForShardAndEpoch(core.MetachainShardId, 0))) assert.Equal(t, rwdInfo.rewardsFromProtocol.Uint64(), protocolRewards) } diff --git a/epochStart/metachain/systemSCs_test.go b/epochStart/metachain/systemSCs_test.go index ee3ea149cf4..bc8350bfcb3 100644 --- a/epochStart/metachain/systemSCs_test.go +++ b/epochStart/metachain/systemSCs_test.go @@ -966,7 +966,7 @@ func createFullArgumentsForSystemSCProcessing(enableEpochsConfig config.EnableEp StakingDataProvider: stakingSCProvider, AuctionListSelector: als, NodesConfigProvider: &shardingMocks.NodesCoordinatorStub{ - ConsensusGroupSizeCalled: func(shardID uint32) int { + ConsensusGroupSizeCalled: func(shardID uint32, _ uint32) int { if shardID == core.MetachainShardId { return 400 } diff --git a/epochStart/metachain/trigger.go b/epochStart/metachain/trigger.go index d0ccd0edebd..9d4855bb11e 100644 --- a/epochStart/metachain/trigger.go +++ b/epochStart/metachain/trigger.go @@ -15,13 +15,13 @@ import ( "github.com/multiversx/mx-chain-core-go/display" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-logger-go" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/storage" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("epochStart/metachain") diff --git a/epochStart/metachain/validators_test.go b/epochStart/metachain/validators_test.go index 662b0192044..2ece21d91d7 100644 --- a/epochStart/metachain/validators_test.go +++ b/epochStart/metachain/validators_test.go @@ -15,6 +15,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" @@ -22,12 +25,11 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" vics "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockValidatorInfo() state.ShardValidatorsInfoMapHandler { @@ -128,7 +130,7 @@ func 
createMockEpochValidatorInfoCreatorsArguments() ArgsNewValidatorInfoCreator Marshalizer: &mock.MarshalizerMock{}, DataPool: &dataRetrieverMock.PoolsHolderStub{ MiniBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ RemoveCalled: func(key []byte) {}, } }, diff --git a/epochStart/mock/coreComponentsMock.go b/epochStart/mock/coreComponentsMock.go index b2f0003d842..e02642b3538 100644 --- a/epochStart/mock/coreComponentsMock.go +++ b/epochStart/mock/coreComponentsMock.go @@ -8,33 +8,39 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" ) // CoreComponentsMock - type CoreComponentsMock struct { - IntMarsh marshal.Marshalizer - Marsh marshal.Marshalizer - Hash hashing.Hasher - EpochNotifierField process.EpochNotifier - EnableEpochsHandlerField common.EnableEpochsHandler - TxSignHasherField hashing.Hasher - UInt64ByteSliceConv typeConverters.Uint64ByteSliceConverter - AddrPubKeyConv core.PubkeyConverter - ValPubKeyConv core.PubkeyConverter - PathHdl storage.PathManagerHandler - ChainIdCalled func() string - MinTransactionVersionCalled func() uint32 - GenesisNodesSetupCalled func() sharding.GenesisNodesSetupHandler - TxVersionCheckField process.TxVersionCheckerHandler - ChanStopNode chan endProcess.ArgEndProcess - NodeTypeProviderField core.NodeTypeProviderHandler - ProcessStatusHandlerInstance common.ProcessStatusHandler - HardforkTriggerPubKeyField []byte - mutCore sync.RWMutex + IntMarsh marshal.Marshalizer + Marsh marshal.Marshalizer + Hash hashing.Hasher + EpochNotifierField process.EpochNotifier + EnableEpochsHandlerField common.EnableEpochsHandler + TxSignHasherField hashing.Hasher + UInt64ByteSliceConv typeConverters.Uint64ByteSliceConverter + AddrPubKeyConv core.PubkeyConverter + ValPubKeyConv core.PubkeyConverter + PathHdl storage.PathManagerHandler + ChainIdCalled func() string + MinTransactionVersionCalled func() uint32 + GenesisNodesSetupCalled func() sharding.GenesisNodesSetupHandler + TxVersionCheckField process.TxVersionCheckerHandler + ChanStopNode chan endProcess.ArgEndProcess + NodeTypeProviderField core.NodeTypeProviderHandler + ProcessStatusHandlerInstance common.ProcessStatusHandler + HardforkTriggerPubKeyField []byte + ChainParametersHandlerField process.ChainParametersHandler + ChainParametersSubscriberField process.ChainParametersSubscriber + FieldsSizeCheckerField common.FieldsSizeChecker + EpochChangeGracePeriodHandlerField common.EpochChangeGracePeriodHandler + mutCore sync.RWMutex } // ChanStopNodeProcess - @@ -155,6 +161,30 @@ func (ccm *CoreComponentsMock) HardforkTriggerPubKey() []byte { return ccm.HardforkTriggerPubKeyField } +// ChainParametersHandler - +func (ccm *CoreComponentsMock) ChainParametersHandler() process.ChainParametersHandler { + if ccm.ChainParametersHandlerField != nil { + return ccm.ChainParametersHandlerField + } + + return &chainParameters.ChainParametersHolderMock{} +} + +// ChainParametersSubscriber - +func (ccm *CoreComponentsMock) ChainParametersSubscriber() process.ChainParametersSubscriber { + return 
ccm.ChainParametersSubscriberField +} + +// FieldsSizeChecker - +func (ccm *CoreComponentsMock) FieldsSizeChecker() common.FieldsSizeChecker { + return ccm.FieldsSizeCheckerField +} + +// EpochChangeGracePeriodHandler - +func (ccm *CoreComponentsMock) EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler { + return ccm.EpochChangeGracePeriodHandlerField +} + // IsInterfaceNil - func (ccm *CoreComponentsMock) IsInterfaceNil() bool { return ccm == nil diff --git a/epochStart/notifier/common.go b/epochStart/notifier/common.go index b535bd54589..bf36da4b45e 100644 --- a/epochStart/notifier/common.go +++ b/epochStart/notifier/common.go @@ -2,6 +2,7 @@ package notifier import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/epochStart/notifier/epochStartSubscriptionHandler.go b/epochStart/notifier/epochStartSubscriptionHandler.go index 1e4141a96dd..3d2041189ce 100644 --- a/epochStart/notifier/epochStartSubscriptionHandler.go +++ b/epochStart/notifier/epochStartSubscriptionHandler.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/epochStart/shardchain/peerMiniBlocksSyncer_test.go b/epochStart/shardchain/peerMiniBlocksSyncer_test.go index f58ef588a0d..3e131fa7074 100644 --- a/epochStart/shardchain/peerMiniBlocksSyncer_test.go +++ b/epochStart/shardchain/peerMiniBlocksSyncer_test.go @@ -9,18 +9,21 @@ import ( "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func createDefaultArguments() ArgPeerMiniBlockSyncer { defaultArgs := ArgPeerMiniBlockSyncer{ - MiniBlocksPool: testscommon.NewCacherStub(), + MiniBlocksPool: cache.NewCacherStub(), ValidatorsInfoPool: testscommon.NewShardedDataStub(), RequestHandler: &testscommon.RequestHandlerStub{}, } @@ -63,7 +66,7 @@ func TestNewValidatorInfoProcessor_NilRequestHandlerShouldErr(t *testing.T) { func TestValidatorInfoProcessor_IsInterfaceNil(t *testing.T) { args := createDefaultArguments() - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, } @@ -76,7 +79,7 @@ func TestValidatorInfoProcessor_IsInterfaceNil(t *testing.T) { func TestValidatorInfoProcessor_ShouldWork(t *testing.T) { args := createDefaultArguments() - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, } @@ -89,7 +92,7 @@ func TestValidatorInfoProcessor_ShouldWork(t *testing.T) { func TestValidatorInfoProcessor_ProcessMetaBlockThatIsNoStartOfEpochShouldWork(t *testing.T) { args := createDefaultArguments() - args.MiniBlocksPool = 
&testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, } @@ -104,7 +107,7 @@ func TestValidatorInfoProcessor_ProcessMetaBlockThatIsNoStartOfEpochShouldWork(t func TestValidatorInfoProcessor_ProcesStartOfEpochWithNoPeerMiniblocksShouldWork(t *testing.T) { args := createDefaultArguments() - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, } @@ -120,7 +123,7 @@ func TestValidatorInfoProcessor_ProcesStartOfEpochWithNoPeerMiniblocksShouldWork epochStartHeader.MiniBlockHeaders = []block.MiniBlockHeader{miniBlockHeader} peekCalled := false - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, @@ -182,7 +185,7 @@ func TestValidatorInfoProcessor_ProcesStartOfEpochWithPeerMiniblocksInPoolShould epochStartHeader.EpochStart.LastFinalizedHeaders = []block.EpochStartShardData{{ShardID: 0, RootHash: hash, HeaderHash: hash}} epochStartHeader.MiniBlockHeaders = []block.MiniBlockHeader{miniBlockHeader} - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { }, @@ -245,7 +248,7 @@ func TestValidatorInfoProcessor_ProcesStartOfEpochWithMissinPeerMiniblocksShould epochStartHeader.MiniBlockHeaders = []block.MiniBlockHeader{miniBlockHeader} var receivedMiniblock func(key []byte, value interface{}) - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { receivedMiniblock = f }, @@ -309,7 +312,7 @@ func TestValidatorInfoProcessor_ProcesStartOfEpochWithMissinPeerMiniblocksTimeou epochStartHeader.MiniBlockHeaders = []block.MiniBlockHeader{miniBlockHeader} var receivedMiniblock func(key []byte, value interface{}) - args.MiniBlocksPool = &testscommon.CacherStub{ + args.MiniBlocksPool = &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, value interface{})) { receivedMiniblock = f }, diff --git a/epochStart/shardchain/trigger.go b/epochStart/shardchain/trigger.go index 50b89a16bd0..f1de14fa7e1 100644 --- a/epochStart/shardchain/trigger.go +++ b/epochStart/shardchain/trigger.go @@ -79,6 +79,7 @@ type trigger struct { mapFinalizedEpochs map[uint32]string headersPool dataRetriever.HeadersPool + proofsPool dataRetriever.ProofsPool miniBlocksPool storage.Cacher validatorInfoPool dataRetriever.ShardedDataCacherNotifier currentEpochValidatorInfoPool epochStart.ValidatorInfoCacher @@ -170,6 +171,9 @@ func NewEpochStartTrigger(args *ArgsShardEpochStartTrigger) (*trigger, error) { if check.IfNil(args.DataPool.Headers()) { return nil, epochStart.ErrNilMetaBlocksPool } + if check.IfNil(args.DataPool.Proofs()) { + return nil, epochStart.ErrNilProofsPool + } if check.IfNil(args.DataPool.MiniBlocks()) { return nil, epochStart.ErrNilMiniBlockPool } @@ -247,6 +251,7 @@ func NewEpochStartTrigger(args *ArgsShardEpochStartTrigger) (*trigger, error) { mapEpochStartHdrs: make(map[string]data.HeaderHandler), mapFinalizedEpochs: make(map[uint32]string), headersPool: args.DataPool.Headers(), + proofsPool: args.DataPool.Proofs(), miniBlocksPool: args.DataPool.MiniBlocks(), validatorInfoPool: args.DataPool.ValidatorsInfo(), currentEpochValidatorInfoPool: args.DataPool.CurrentEpochValidatorInfo(), @@ -271,6 +276,7 @@ func 
NewEpochStartTrigger(args *ArgsShardEpochStartTrigger) (*trigger, error) { } t.headersPool.RegisterHandler(t.receivedMetaBlock) + t.proofsPool.RegisterHandler(t.receivedProof) err = t.saveState(t.triggerStateKey) if err != nil { @@ -555,12 +561,68 @@ func (t *trigger) changeEpochFinalityAttestingRoundIfNeeded( t.epochFinalityAttestingRound = metaHdr.GetRound() } +func (t *trigger) receivedProof(headerProof data.HeaderProofHandler) { + if check.IfNil(headerProof) { + return + } + if headerProof.GetHeaderShardId() != core.MetachainShardId { + return + } + + log.Debug("received proof in trigger", "proof for header hash", headerProof.GetHeaderHash()) + t.mutTrigger.Lock() + defer t.mutTrigger.Unlock() + + header, err := t.headersPool.GetHeaderByHash(headerProof.GetHeaderHash()) + if err != nil { + return + } + + t.checkMetaHeaderForEpochTriggerEquivalentProofs(header, headerProof.GetHeaderHash()) +} + // receivedMetaBlock is a callback function when a new metablock was received // upon receiving checks if trigger can be updated func (t *trigger) receivedMetaBlock(headerHandler data.HeaderHandler, metaBlockHash []byte) { + if headerHandler.GetShardID() != core.MetachainShardId { + return + } + + log.Debug("received meta header in trigger", "header hash", metaBlockHash) + if t.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, headerHandler.GetEpoch()) { + proof, err := t.proofsPool.GetProof(headerHandler.GetShardID(), metaBlockHash) + if err != nil { + return + } + + t.mutTrigger.Lock() + t.checkMetaHeaderForEpochTriggerEquivalentProofs(headerHandler, proof.GetHeaderHash()) + t.mutTrigger.Unlock() + return + } + t.mutTrigger.Lock() defer t.mutTrigger.Unlock() + t.checkMetaHeaderForEpochTriggerLegacy(headerHandler, metaBlockHash) +} + +func (t *trigger) checkMetaHeaderForEpochTriggerEquivalentProofs(headerHandler data.HeaderHandler, metaBlockHash []byte) { + metaHdr, ok := headerHandler.(*block.MetaBlock) + if !ok { + return + } + log.Debug("trigger.checkMetaHeaderForEpochTriggerEquivalentProofs", "metaHdr epoch", metaHdr.GetEpoch(), "metaBlockHash", metaBlockHash) + if !t.shouldUpdateTrigger(metaHdr, metaBlockHash) { + return + } + + log.Debug("trigger.updateTriggerHeaderData") + t.updateTriggerHeaderData(metaHdr, metaBlockHash) + t.updateTriggerFromMeta() +} + +func (t *trigger) checkMetaHeaderForEpochTriggerLegacy(headerHandler data.HeaderHandler, metaBlockHash []byte) { metaHdr, ok := headerHandler.(*block.MetaBlock) if !ok { return @@ -574,28 +636,41 @@ func (t *trigger) receivedMetaBlock(headerHandler data.HeaderHandler, metaBlockH } } - if !t.newEpochHdrReceived && !metaHdr.IsStartOfEpochBlock() { + if !t.shouldUpdateTrigger(metaHdr, metaBlockHash) { return } + t.updateTriggerHeaderData(metaHdr, metaBlockHash) + t.updateTriggerFromMeta() +} + +func (t *trigger) shouldUpdateTrigger(metaHdr *block.MetaBlock, metaBlockHash []byte) bool { + if !t.newEpochHdrReceived && !metaHdr.IsStartOfEpochBlock() { + return false + } + isMetaStartOfEpochForCurrentEpoch := metaHdr.Epoch == t.epoch && metaHdr.IsStartOfEpochBlock() if isMetaStartOfEpochForCurrentEpoch { - return + return false } - if _, ok = t.mapHashHdr[string(metaBlockHash)]; ok { - return + if _, ok := t.mapHashHdr[string(metaBlockHash)]; ok { + return false } - if _, ok = t.mapEpochStartHdrs[string(metaBlockHash)]; ok { - return + if _, ok := t.mapEpochStartHdrs[string(metaBlockHash)]; ok { + return false } + return true +} + +func (t *trigger) updateTriggerHeaderData(metaHdr *block.MetaBlock, metaBlockHash []byte) { if 
metaHdr.IsStartOfEpochBlock() { t.newEpochHdrReceived = true t.mapEpochStartHdrs[string(metaBlockHash)] = metaHdr // waiting for late broadcast of mini blocks and transactions to be done and received wait := t.extraDelayForRequestBlockInfo - roundDifferences := t.roundHandler.Index() - int64(headerHandler.GetRound()) + roundDifferences := t.roundHandler.Index() - int64(metaHdr.GetRound()) if roundDifferences > 1 { wait = 0 } @@ -605,8 +680,6 @@ func (t *trigger) receivedMetaBlock(headerHandler data.HeaderHandler, metaBlockH t.mapHashHdr[string(metaBlockHash)] = metaHdr t.mapNonceHashes[metaHdr.Nonce] = append(t.mapNonceHashes[metaHdr.Nonce], string(metaBlockHash)) - - t.updateTriggerFromMeta() } // call only if mutex is locked before @@ -650,6 +723,7 @@ func (t *trigger) updateTriggerFromMeta() { } canActivateEpochStart, finalityAttestingRound := t.checkIfTriggerCanBeActivated(currMetaInfo.hash, currMetaInfo.hdr) + log.Debug("trigger.updateTriggerFromMeta", "canActivateEpochStart", canActivateEpochStart, "finalityAttestingRound", finalityAttestingRound) if canActivateEpochStart && t.metaEpoch < currMetaInfo.hdr.GetEpoch() { t.metaEpoch = currMetaInfo.hdr.GetEpoch() t.isEpochStart = true @@ -722,10 +796,24 @@ func (t *trigger) isMetaBlockValid(hash string, metaHdr data.HeaderHandler) bool return true } -func (t *trigger) isMetaBlockFinal(_ string, metaHdr data.HeaderHandler) (bool, uint64) { +func (t *trigger) isMetaBlockFinal(hash string, metaHdr data.HeaderHandler) (bool, uint64) { + if !t.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, metaHdr.GetEpoch()) { + return t.isMetaBlockFinalLegacy(hash, metaHdr) + } + + hasProof := t.proofsPool.HasProof(metaHdr.GetShardID(), []byte(hash)) + if !hasProof { + return false, 0 + } + + return true, metaHdr.GetRound() +} + +func (t *trigger) isMetaBlockFinalLegacy(_ string, metaHdr data.HeaderHandler) (bool, uint64) { nextBlocksVerified := uint64(0) finalityAttestingRound := metaHdr.GetRound() currHdr := metaHdr + for nonce := metaHdr.GetNonce() + 1; nonce <= metaHdr.GetNonce()+t.finality; nonce++ { currHash, err := core.CalculateHash(t.marshaller, t.hasher, currHdr) if err != nil { diff --git a/epochStart/shardchain/triggerRegistry.go b/epochStart/shardchain/triggerRegistry.go index 899e99e83bc..d3f5e8d18c6 100644 --- a/epochStart/shardchain/triggerRegistry.go +++ b/epochStart/shardchain/triggerRegistry.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/epochStart/shardchain/triggerRegistry_test.go b/epochStart/shardchain/triggerRegistry_test.go index 5adccc849e1..970f48f6a73 100644 --- a/epochStart/shardchain/triggerRegistry_test.go +++ b/epochStart/shardchain/triggerRegistry_test.go @@ -6,13 +6,14 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" 
storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func cloneTrigger(t *trigger) *trigger { @@ -42,6 +43,7 @@ func cloneTrigger(t *trigger) *trigger { rt.requestHandler = t.requestHandler rt.epochStartNotifier = t.epochStartNotifier rt.headersPool = t.headersPool + rt.proofsPool = t.proofsPool rt.epochStartShardHeader = t.epochStartShardHeader rt.epochStartMeta = t.epochStartMeta rt.shardHdrStorage = t.shardHdrStorage diff --git a/epochStart/shardchain/trigger_test.go b/epochStart/shardchain/trigger_test.go index 4da84e6c46c..be4c6ccb00e 100644 --- a/epochStart/shardchain/trigger_test.go +++ b/epochStart/shardchain/trigger_test.go @@ -12,20 +12,22 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockShardEpochStartTriggerArguments() *ArgsShardEpochStartTrigger { @@ -43,11 +45,14 @@ func createMockShardEpochStartTriggerArguments() *ArgsShardEpochStartTrigger { return &mock.HeadersCacherStub{} }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &vic.ValidatorInfoCacherStub{} }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, }, Storage: &storageStubs.ChainStorerStub{ GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { @@ -207,7 +212,7 @@ func TestNewEpochStartTrigger_NilHeadersPoolShouldErr(t *testing.T) { return nil }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, } epochStartTrigger, err := NewEpochStartTrigger(args) @@ -376,7 +381,7 @@ func TestTrigger_ReceivedHeaderIsEpochStartTrueWithPeerMiniblocks(t *testing.T) } }, MiniBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { if bytes.Equal(key, peerMiniBlockHash) { return peerMiniblock, true @@ -388,6 +393,9 @@ func 
TestTrigger_ReceivedHeaderIsEpochStartTrueWithPeerMiniblocks(t *testing.T) CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &vic.ValidatorInfoCacherStub{} }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } args.Uint64Converter = &mock.Uint64ByteSliceConverterMock{ ToByteSliceCalled: func(u uint64) []byte { @@ -718,7 +726,7 @@ func TestTrigger_UpdateMissingValidatorsInfo(t *testing.T) { return &mock.HeadersCacherStub{} }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { return &vic.ValidatorInfoCacherStub{} @@ -737,6 +745,9 @@ func TestTrigger_UpdateMissingValidatorsInfo(t *testing.T) { }, } }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } epochStartTrigger, _ := NewEpochStartTrigger(args) @@ -777,3 +788,178 @@ func TestTrigger_AddMissingValidatorsInfo(t *testing.T) { assert.Equal(t, uint32(1), epochStartTrigger.mapMissingValidatorsInfo["c"]) epochStartTrigger.mutMissingValidatorsInfo.RUnlock() } + +func TestTrigger_ReceivedProof(t *testing.T) { + t.Parallel() + + t.Run("early exits", func(t *testing.T) { + t.Parallel() + + args := createMockShardEpochStartTriggerArguments() + args.DataPool = &dataRetrieverMock.PoolsHolderStub{ + HeadersCalled: func() dataRetriever.HeadersPool { + return &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + require.Fail(t, "should have not been called") + return nil, nil + }, + } + }, + } + epochStartTrigger, _ := NewEpochStartTrigger(args) + + // nil proof + epochStartTrigger.receivedProof(nil) + + epochStartTrigger.receivedProof(&block.HeaderProof{ + HeaderShardId: 0, // not meta + }) + }) + t.Run("GetHeaderByHash error should early exit", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected error") + args := createMockShardEpochStartTriggerArguments() + args.DataPool = &dataRetrieverMock.PoolsHolderStub{ + HeadersCalled: func() dataRetriever.HeadersPool { + return &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return nil, expectedErr + }, + } + }, + MiniBlocksCalled: func() storage.Cacher { + return cache.NewCacherStub() + }, + CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { + return &vic.ValidatorInfoCacherStub{} + }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, + } + args.EpochStartNotifier = &mock.EpochStartNotifierStub{ + NotifyEpochChangeConfirmedCalled: func(epoch uint32) { + require.Fail(t, "should not have been called") + }, + } + epochStartTrigger, _ := NewEpochStartTrigger(args) + + epochStartTrigger.receivedProof(&block.HeaderProof{ + HeaderShardId: core.MetachainShardId, + }) + }) + t.Run("not meta block should exit", func(t *testing.T) { + t.Parallel() + + args := createMockShardEpochStartTriggerArguments() + args.DataPool = &dataRetrieverMock.PoolsHolderStub{ + HeadersCalled: func() dataRetriever.HeadersPool { + return &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return &block.Header{}, nil + }, + } + }, + MiniBlocksCalled: func() storage.Cacher { + return cache.NewCacherStub() + }, + CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { + return &vic.ValidatorInfoCacherStub{} + }, + ProofsCalled: 
func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, + } + args.EpochStartNotifier = &mock.EpochStartNotifierStub{ + NotifyEpochChangeConfirmedCalled: func(epoch uint32) { + require.Fail(t, "should not have been called") + }, + } + epochStartTrigger, _ := NewEpochStartTrigger(args) + + epochStartTrigger.receivedProof(&block.HeaderProof{ + HeaderShardId: core.MetachainShardId, + }) + }) + t.Run("should not update trigger should early exit", func(t *testing.T) { + t.Parallel() + + args := createMockShardEpochStartTriggerArguments() + args.DataPool = &dataRetrieverMock.PoolsHolderStub{ + HeadersCalled: func() dataRetriever.HeadersPool { + return &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return &block.MetaBlock{}, nil + }, + } + }, + MiniBlocksCalled: func() storage.Cacher { + return cache.NewCacherStub() + }, + CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { + return &vic.ValidatorInfoCacherStub{} + }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, + } + args.EpochStartNotifier = &mock.EpochStartNotifierStub{ + NotifyEpochChangeConfirmedCalled: func(epoch uint32) { + require.Fail(t, "should not have been called") + }, + } + epochStartTrigger, _ := NewEpochStartTrigger(args) + + epochStartTrigger.receivedProof(&block.HeaderProof{ + HeaderShardId: core.MetachainShardId, + }) + }) + t.Run("should work and notify", func(t *testing.T) { + t.Parallel() + + args := createMockShardEpochStartTriggerArguments() + args.Validity = 2 + args.DataPool = &dataRetrieverMock.PoolsHolderStub{ + HeadersCalled: func() dataRetriever.HeadersPool { + return &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return &block.MetaBlock{ + Epoch: 1, + Nonce: 3, + EpochStart: block.EpochStart{ + LastFinalizedHeaders: []block.EpochStartShardData{ + { + ShardID: 0, + }, + }, + }, + }, nil + }, + } + }, + MiniBlocksCalled: func() storage.Cacher { + return cache.NewCacherStub() + }, + CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { + return &vic.ValidatorInfoCacherStub{} + }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, + } + wasCalled := false + args.EpochStartNotifier = &mock.EpochStartNotifierStub{ + NotifyEpochChangeConfirmedCalled: func(epoch uint32) { + wasCalled = true + }, + } + epochStartTrigger, _ := NewEpochStartTrigger(args) + + epochStartTrigger.receivedProof(&block.HeaderProof{ + HeaderShardId: core.MetachainShardId, + }) + + require.True(t, wasCalled) + }) +} diff --git a/errors/errors.go b/errors/errors.go index dd475327876..7d2c26bff7a 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -475,6 +475,9 @@ var ErrNilESDTDataStorage = errors.New("nil esdt data storage") // ErrNilEnableEpochsHandler signals that a nil enable epochs handler was provided var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") +// ErrNilChainParametersHandler signals that a nil chain parameters handler was provided +var ErrNilChainParametersHandler = errors.New("nil chain parameters handler") + // ErrSignerNotSupported signals that a not supported signer was provided var ErrSignerNotSupported = errors.New("signer not supported") @@ -598,3 +601,9 @@ var ErrNilSentSignatureTracker = errors.New("nil sent signature tracker") // ErrNilEpochSystemSCProcessor defines the error for setting a nil EpochSystemSCProcessor var 
ErrNilEpochSystemSCProcessor = errors.New("nil epoch system SC processor") + +// ErrNilFieldsSizeChecker signals that a nil fields size checker has been provided +var ErrNilFieldsSizeChecker = errors.New("nil fields size checker") + +// ErrNilTrieLeavesRetriever defines the error for setting a nil TrieLeavesRetriever +var ErrNilTrieLeavesRetriever = errors.New("nil trie leaves retriever") diff --git a/facade/initial/initialNodeFacade.go b/facade/initial/initialNodeFacade.go index d6043dbcd62..ea9268d0bde 100644 --- a/facade/initial/initialNodeFacade.go +++ b/facade/initial/initialNodeFacade.go @@ -346,6 +346,11 @@ func (inf *initialNodeFacade) GetKeyValuePairs(_ string, _ api.AccountQueryOptio return nil, api.BlockInfo{}, errNodeStarting } +// IterateKeys returns error +func (inf *initialNodeFacade) IterateKeys(_ string, _ uint, _ [][]byte, _ api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) { + return nil, nil, api.BlockInfo{}, errNodeStarting +} + // GetGuardianData returns error func (inf *initialNodeFacade) GetGuardianData(_ string, _ api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) { return api.GuardianData{}, api.BlockInfo{}, errNodeStarting diff --git a/facade/interface.go b/facade/interface.go index 309f6c98d6f..2dfa8b503bd 100644 --- a/facade/interface.go +++ b/facade/interface.go @@ -41,6 +41,9 @@ type NodeHandler interface { // GetKeyValuePairs returns the key-value pairs under a given address GetKeyValuePairs(address string, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, error) + // IterateKeys returns the key-value pairs under a given address starting from a given state + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) + // GetAllIssuedESDTs returns all the issued esdt tokens from esdt system smart contract GetAllIssuedESDTs(tokenType string, ctx context.Context) ([]string, error) diff --git a/facade/mock/nodeStub.go b/facade/mock/nodeStub.go index 1e779e0ebce..e7b2817a32e 100644 --- a/facade/mock/nodeStub.go +++ b/facade/mock/nodeStub.go @@ -49,6 +49,7 @@ type NodeStub struct { GetESDTsWithRoleCalled func(address string, role string, options api.AccountQueryOptions, ctx context.Context) ([]string, api.BlockInfo, error) GetESDTsRolesCalled func(address string, options api.AccountQueryOptions, ctx context.Context) (map[string][]string, api.BlockInfo, error) GetKeyValuePairsCalled func(address string, options api.AccountQueryOptions, ctx context.Context) (map[string]string, api.BlockInfo, error) + IterateKeysCalled func(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) GetAllIssuedESDTsCalled func(tokenType string, ctx context.Context) ([]string, error) GetProofCalled func(rootHash string, key string) (*common.GetProofResponse, error) GetProofDataTrieCalled func(rootHash string, address string, key string) (*common.GetProofResponse, *common.GetProofResponse, error) @@ -112,6 +113,15 @@ func (ns *NodeStub) GetKeyValuePairs(address string, options api.AccountQueryOpt return nil, api.BlockInfo{}, nil } +// IterateKeys - +func (ns *NodeStub) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) { + if ns.IterateKeysCalled != nil { + return 
ns.IterateKeysCalled(address, numKeys, iteratorState, options, ctx) + } + + return nil, nil, api.BlockInfo{}, nil +} + // GetValueForKey - func (ns *NodeStub) GetValueForKey(address string, key string, options api.AccountQueryOptions) (string, api.BlockInfo, error) { if ns.GetValueForKeyCalled != nil { diff --git a/facade/nodeFacade.go b/facade/nodeFacade.go index c3a7f290edf..e516b506b52 100644 --- a/facade/nodeFacade.go +++ b/facade/nodeFacade.go @@ -229,6 +229,14 @@ func (nf *nodeFacade) GetKeyValuePairs(address string, options apiData.AccountQu return nf.node.GetKeyValuePairs(address, options, ctx) } +// IterateKeys starts from the given iteratorState and returns the next key-value pairs and the new iteratorState +func (nf *nodeFacade) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options apiData.AccountQueryOptions) (map[string]string, [][]byte, apiData.BlockInfo, error) { + ctx, cancel := nf.getContextForApiTrieRangeOperations() + defer cancel() + + return nf.node.IterateKeys(address, numKeys, iteratorState, options, ctx) +} + // GetGuardianData returns the guardian data for the provided address func (nf *nodeFacade) GetGuardianData(address string, options apiData.AccountQueryOptions) (apiData.GuardianData, apiData.BlockInfo, error) { return nf.node.GetGuardianData(address, options) diff --git a/factory/api/apiResolverFactory.go b/factory/api/apiResolverFactory.go index e7a7d761b32..1739154abe4 100644 --- a/factory/api/apiResolverFactory.go +++ b/factory/api/apiResolverFactory.go @@ -733,6 +733,8 @@ func createAPIBlockProcessorArgs(args *ApiResolverArgs, apiTransactionHandler ex AccountsRepository: args.StateComponents.AccountsRepository(), ScheduledTxsExecutionHandler: args.ProcessComponents.ScheduledTxsExecutionHandler(), EnableEpochsHandler: args.CoreComponents.EnableEpochsHandler(), + ProofsPool: args.DataComponents.Datapool().Proofs(), + BlockChain: args.DataComponents.Blockchain(), } return blockApiArgs, nil diff --git a/factory/block/headerVersionHandler_test.go b/factory/block/headerVersionHandler_test.go index 9de5238810b..4a17cb291a2 100644 --- a/factory/block/headerVersionHandler_test.go +++ b/factory/block/headerVersionHandler_test.go @@ -10,10 +10,12 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -50,7 +52,7 @@ func TestNewHeaderIntegrityVerifierr_InvalidVersionElementOnEpochValuesEqualShou }, }, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.True(t, check.IfNil(hdrIntVer)) require.True(t, errors.Is(err, ErrInvalidVersionOnEpochValues)) @@ -67,7 +69,7 @@ func TestNewHeaderIntegrityVerifier_InvalidVersionElementOnStringTooLongShouldEr }, }, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.True(t, check.IfNil(hdrIntVer)) require.True(t, errors.Is(err, ErrInvalidVersionStringTooLong)) @@ -79,7 +81,7 @@ func TestNewHeaderIntegrityVerifierr_InvalidDefaultVersionShouldErr(t *testing.T hdrIntVer, err := NewHeaderVersionHandler( versionsCorrectlyConstructed, "", - 
&testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.True(t, check.IfNil(hdrIntVer)) require.True(t, errors.Is(err, ErrInvalidSoftwareVersion)) @@ -103,7 +105,7 @@ func TestNewHeaderIntegrityVerifier_EmptyListShouldErr(t *testing.T) { hdrIntVer, err := NewHeaderVersionHandler( make([]config.VersionByEpochs, 0), defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.True(t, check.IfNil(hdrIntVer)) require.True(t, errors.Is(err, ErrEmptyVersionsByEpochsList)) @@ -120,7 +122,7 @@ func TestNewHeaderIntegrityVerifier_ZerothElementIsNotOnEpochZeroShouldErr(t *te }, }, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.True(t, check.IfNil(hdrIntVer)) require.True(t, errors.Is(err, ErrInvalidVersionOnEpochValues)) @@ -132,7 +134,7 @@ func TestNewHeaderIntegrityVerifier_ShouldWork(t *testing.T) { hdrIntVer, err := NewHeaderVersionHandler( versionsCorrectlyConstructed, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) require.False(t, check.IfNil(hdrIntVer)) require.NoError(t, err) @@ -147,7 +149,7 @@ func TestHeaderIntegrityVerifier_PopulatedReservedShouldErr(t *testing.T) { hdrIntVer, _ := NewHeaderVersionHandler( make([]config.VersionByEpochs, 0), defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) err := hdrIntVer.Verify(hdr) require.Equal(t, process.ErrReservedFieldInvalid, err) @@ -159,7 +161,7 @@ func TestHeaderIntegrityVerifier_VerifySoftwareVersionEmptyVersionInHeaderShould hdrIntVer, _ := NewHeaderVersionHandler( make([]config.VersionByEpochs, 0), defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) err := hdrIntVer.Verify(&block.MetaBlock{}) require.True(t, errors.Is(err, ErrInvalidSoftwareVersion)) @@ -180,7 +182,7 @@ func TestHeaderIntegrityVerifierr_VerifySoftwareVersionWrongVersionShouldErr(t * }, }, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) err := hdrIntVer.Verify( &block.MetaBlock{ @@ -207,7 +209,7 @@ func TestHeaderIntegrityVerifier_VerifySoftwareVersionWildcardShouldWork(t *test }, }, defaultVersion, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) err := hdrIntVer.Verify( &block.MetaBlock{ @@ -227,7 +229,7 @@ func TestHeaderIntegrityVerifier_VerifyShouldWork(t *testing.T) { hdrIntVer, _ := NewHeaderVersionHandler( versionsCorrectlyConstructed, "software", - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) mb := &block.MetaBlock{ SoftwareVersion: []byte("software"), @@ -243,7 +245,7 @@ func TestHeaderIntegrityVerifier_VerifyNotWildcardShouldWork(t *testing.T) { hdrIntVer, _ := NewHeaderVersionHandler( versionsCorrectlyConstructed, "software", - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) mb := &block.MetaBlock{ SoftwareVersion: []byte("v1"), @@ -260,7 +262,7 @@ func TestHeaderIntegrityVerifier_GetVersionShouldWork(t *testing.T) { hdrIntVer, _ := NewHeaderVersionHandler( versionsCorrectlyConstructed, defaultVersion, - &testscommon.CacherStub{ + &cache.CacherStub{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) bool { atomic.AddUint32(&numPutCalls, 1) epoch := binary.BigEndian.Uint32(key) @@ -311,7 +313,7 @@ func TestHeaderIntegrityVerifier_ExistsInInternalCacheShouldReturn(t *testing.T) hdrIntVer, _ := NewHeaderVersionHandler( versionsCorrectlyConstructed, defaultVersion, - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return cachedVersion, true }, diff --git a/factory/bootstrap/bootstrapComponents.go b/factory/bootstrap/bootstrapComponents.go index 
a9ef7851ccb..2154289a4c7 100644 --- a/factory/bootstrap/bootstrapComponents.go +++ b/factory/bootstrap/bootstrapComponents.go @@ -3,9 +3,13 @@ package bootstrap import ( "fmt" "path/filepath" + "time" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + interceptorFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" + logger "github.com/multiversx/mx-chain-logger-go" + nodeFactory "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" @@ -24,23 +28,23 @@ import ( storageFactory "github.com/multiversx/mx-chain-go/storage/factory" "github.com/multiversx/mx-chain-go/storage/latestData" "github.com/multiversx/mx-chain-go/storage/storageunit" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("factory") // BootstrapComponentsFactoryArgs holds the arguments needed to create a bootstrap components factory type BootstrapComponentsFactoryArgs struct { - Config config.Config - RoundConfig config.RoundConfig - PrefConfig config.Preferences - ImportDbConfig config.ImportDbConfig - FlagsConfig config.ContextFlagsConfig - WorkingDir string - CoreComponents factory.CoreComponentsHolder - CryptoComponents factory.CryptoComponentsHolder - NetworkComponents factory.NetworkComponentsHolder - StatusCoreComponents factory.StatusCoreComponentsHolder + Config config.Config + RoundConfig config.RoundConfig + PrefConfig config.Preferences + ImportDbConfig config.ImportDbConfig + FlagsConfig config.ContextFlagsConfig + WorkingDir string + CoreComponents factory.CoreComponentsHolder + CryptoComponents factory.CryptoComponentsHolder + NetworkComponents factory.NetworkComponentsHolder + StatusCoreComponents factory.StatusCoreComponentsHolder + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } type bootstrapComponentsFactory struct { @@ -198,6 +202,12 @@ func (bcf *bootstrapComponentsFactory) Create() (*bootstrapComponents, error) { return nil, err } + // create a new instance of interceptedDataVerifier which will be used for bootstrap only + interceptedDataVerifierFactory := interceptorFactory.NewInterceptedDataVerifierFactory(interceptorFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Duration(bcf.config.InterceptedDataVerifier.CacheSpanInSec) * time.Second, + CacheExpiry: time.Duration(bcf.config.InterceptedDataVerifier.CacheExpiryInSec) * time.Second, + }) + epochStartBootstrapArgs := bootstrap.ArgsEpochStartBootstrap{ CoreComponentsHolder: bcf.coreComponents, CryptoComponentsHolder: bcf.cryptoComponents, @@ -224,6 +234,8 @@ func (bcf *bootstrapComponentsFactory) Create() (*bootstrapComponents, error) { NodeProcessingMode: common.GetNodeProcessingMode(&bcf.importDbConfig), StateStatsHandler: bcf.statusCoreComponents.StateStatsHandler(), NodesCoordinatorRegistryFactory: nodesCoordinatorRegistryFactory, + EnableEpochsHandler: bcf.coreComponents.EnableEpochsHandler(), + InterceptedDataVerifierFactory: interceptedDataVerifierFactory, } var epochStartBootstrapper factory.EpochStartBootstrapper diff --git a/factory/bootstrap/shardingFactory.go b/factory/bootstrap/shardingFactory.go index 6662129299b..3c23df8fdaa 100644 --- a/factory/bootstrap/shardingFactory.go +++ b/factory/bootstrap/shardingFactory.go @@ -16,6 +16,7 @@ 
import ( "github.com/multiversx/mx-chain-go/epochStart" errErd "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/factory" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage" @@ -114,6 +115,7 @@ func CreateNodesCoordinator( enableEpochsHandler common.EnableEpochsHandler, validatorInfoCacher epochStart.ValidatorInfoCacher, nodesCoordinatorRegistryFactory nodesCoordinator.NodesCoordinatorRegistryFactory, + chainParametersHandler process.ChainParametersHandler, ) (nodesCoordinator.NodesCoordinator, error) { if check.IfNil(nodeShufflerOut) { return nil, errErd.ErrNilShuffleOutCloser @@ -148,8 +150,6 @@ func CreateNodesCoordinator( } nbShards := nodesConfig.NumberOfShards() - shardConsensusGroupSize := int(nodesConfig.GetShardConsensusGroupSize()) - metaConsensusGroupSize := int(nodesConfig.GetMetaConsensusGroupSize()) eligibleNodesInfo, waitingNodesInfo := nodesConfig.InitialNodesInfo() eligibleValidators, errEligibleValidators := nodesCoordinator.NodesInfoToValidators(eligibleNodesInfo) @@ -198,8 +198,7 @@ func CreateNodesCoordinator( } argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: shardConsensusGroupSize, - MetaConsensusGroupSize: metaConsensusGroupSize, + ChainParametersHandler: chainParametersHandler, Marshalizer: marshalizer, Hasher: hasher, Shuffler: nodeShuffler, diff --git a/factory/bootstrap/shardingFactory_test.go b/factory/bootstrap/shardingFactory_test.go index c7a54e077f4..4e6f118ab29 100644 --- a/factory/bootstrap/shardingFactory_test.go +++ b/factory/bootstrap/shardingFactory_test.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" @@ -210,6 +211,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.Equal(t, errErd.ErrNilShuffleOutCloser, err) require.True(t, check.IfNil(nodesC)) @@ -236,6 +238,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.Equal(t, errErd.ErrNilGenesisNodesSetupHandler, err) require.True(t, check.IfNil(nodesC)) @@ -262,6 +265,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.Equal(t, errErd.ErrNilEpochStartNotifier, err) require.True(t, check.IfNil(nodesC)) @@ 
-288,6 +292,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.Equal(t, errErd.ErrNilPublicKey, err) require.True(t, check.IfNil(nodesC)) @@ -314,6 +319,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.Equal(t, errErd.ErrNilBootstrapParamsHandler, err) require.True(t, check.IfNil(nodesC)) @@ -340,6 +346,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.Equal(t, nodesCoordinator.ErrNilNodeStopChannel, err) require.True(t, check.IfNil(nodesC)) @@ -368,6 +375,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.NotNil(t, err) require.True(t, check.IfNil(nodesC)) @@ -400,6 +408,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.True(t, errors.Is(err, expectedErr)) require.True(t, check.IfNil(nodesC)) @@ -432,6 +441,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.True(t, errors.Is(err, expectedErr)) require.True(t, check.IfNil(nodesC)) @@ -464,6 +474,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.NotNil(t, err) require.True(t, check.IfNil(nodesC)) @@ -496,6 +507,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.NotNil(t, err) require.True(t, check.IfNil(nodesC)) @@ -549,6 +561,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.NotNil(t, err) require.True(t, check.IfNil(nodesC)) @@ -602,6 +615,7 @@ func TestCreateNodesCoordinator(t *testing.T) { &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, &validatorInfoCacherMocks.ValidatorInfoCacherStub{}, &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + &chainParameters.ChainParametersHandlerStub{}, ) require.Nil(t, err) require.False(t, check.IfNil(nodesC)) diff --git 
a/factory/consensus/consensusComponents.go b/factory/consensus/consensusComponents.go index decdb7c85fa..39efa2c4240 100644 --- a/factory/consensus/consensusComponents.go +++ b/factory/consensus/consensusComponents.go @@ -9,6 +9,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/core/watchdog" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-storage-go/timecache" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/disabled" "github.com/multiversx/mx-chain-go/config" @@ -16,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/consensus/blacklist" "github.com/multiversx/mx-chain-go/consensus/chronology" "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/consensus/spos/bls/proxy" "github.com/multiversx/mx-chain-go/consensus/spos/sposFactory" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/errors" @@ -25,17 +29,18 @@ import ( "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/process/sync/storageBootstrap" "github.com/multiversx/mx-chain-go/sharding" + nodesCoord "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/multiversx/mx-chain-go/update" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/multiversx/mx-chain-storage-go/timecache" ) var log = logger.GetOrCreate("factory") const defaultSpan = 300 * time.Second +const numSignatureGoRoutinesThrottler = 30 + // ConsensusComponentsFactoryArgs holds the arguments needed to create a consensus components factory type ConsensusComponentsFactoryArgs struct { Config config.Config @@ -78,7 +83,6 @@ type consensusComponents struct { worker factory.ConsensusWorker peerBlacklistHandler consensus.PeerBlacklistHandler consensusTopic string - consensusGroupSize int } // NewConsensusComponentsFactory creates an instance of consensusComponentsFactory @@ -112,13 +116,6 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { cc := &consensusComponents{} - consensusGroupSize, err := getConsensusGroupSize(ccf.coreComponents.GenesisNodesSetup(), ccf.processComponents.ShardCoordinator()) - if err != nil { - return nil, err - } - - cc.consensusGroupSize = int(consensusGroupSize) - blockchain := ccf.dataComponents.Blockchain() notInitializedGenesisBlock := len(blockchain.GetGenesisHeaderHash()) == 0 || check.IfNil(blockchain.GetGenesisHeader()) @@ -142,7 +139,12 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { } epoch := ccf.getEpoch() - consensusState, err := ccf.createConsensusState(epoch, cc.consensusGroupSize) + + consensusGroupSize, err := getConsensusGroupSize(ccf.coreComponents.GenesisNodesSetup(), ccf.processComponents.ShardCoordinator(), ccf.processComponents.NodesCoordinator(), epoch) + if err != nil { + return nil, err + } + consensusState, err := 
ccf.createConsensusState(epoch, consensusGroupSize) if err != nil { return nil, err } @@ -178,6 +180,21 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { return nil, err } + p2pSigningHandler, err := ccf.createP2pSigningHandler() + if err != nil { + return nil, err + } + + argsInvalidSignersCacher := spos.ArgInvalidSignersCache{ + Hasher: ccf.coreComponents.Hasher(), + SigningHandler: p2pSigningHandler, + Marshaller: ccf.coreComponents.InternalMarshalizer(), + } + invalidSignersCache, err := spos.NewInvalidSignersCache(argsInvalidSignersCacher) + if err != nil { + return nil, err + } + workerArgs := &spos.WorkerArgs{ ConsensusService: consensusService, BlockChain: ccf.dataComponents.Blockchain(), @@ -204,6 +221,8 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { AppStatusHandler: ccf.statusCoreComponents.AppStatusHandler(), NodeRedundancyHandler: ccf.processComponents.NodeRedundancyHandler(), PeerBlacklistHandler: cc.peerBlacklistHandler, + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), + InvalidSignersCache: invalidSignersCache, } cc.worker, err = spos.NewWorker(workerArgs) @@ -214,21 +233,15 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { cc.worker.StartWorking() ccf.dataComponents.Datapool().Headers().RegisterHandler(cc.worker.ReceivedHeader) - // apply consensus group size on the input antiflooder just before consensus creation topic - ccf.networkComponents.InputAntiFloodHandler().ApplyConsensusSize( - ccf.processComponents.NodesCoordinator().ConsensusGroupSize( - ccf.processComponents.ShardCoordinator().SelfId()), + ccf.networkComponents.InputAntiFloodHandler().SetConsensusSizeNotifier( + ccf.coreComponents.ChainParametersSubscriber(), + ccf.processComponents.ShardCoordinator().SelfId(), ) err = ccf.createConsensusTopic(cc) if err != nil { return nil, err } - p2pSigningHandler, err := ccf.createP2pSigningHandler() - if err != nil { - return nil, err - } - consensusArgs := &spos.ConsensusCoreArgs{ BlockChain: ccf.dataComponents.Blockchain(), BlockProcessor: ccf.processComponents.BlockProcessor(), @@ -252,6 +265,10 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { MessageSigningHandler: p2pSigningHandler, PeerBlacklistHandler: cc.peerBlacklistHandler, SigningHandler: ccf.cryptoComponents.ConsensusSigningHandler(), + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), + EquivalentProofsPool: ccf.dataComponents.Datapool().Proofs(), + EpochNotifier: ccf.coreComponents.EpochNotifier(), + InvalidSignersCache: invalidSignersCache, } consensusDataContainer, err := spos.NewConsensusCore( @@ -260,28 +277,34 @@ func (ccf *consensusComponentsFactory) Create() (*consensusComponents, error) { if err != nil { return nil, err } - - fct, err := sposFactory.GetSubroundsFactory( - consensusDataContainer, - consensusState, - cc.worker, - ccf.config.Consensus.Type, - ccf.statusCoreComponents.AppStatusHandler(), - ccf.statusComponents.OutportHandler(), - ccf.processComponents.SentSignaturesTracker(), - []byte(ccf.coreComponents.ChainID()), - ccf.networkComponents.NetworkMessenger().ID(), - ) + signatureThrottler, err := throttler.NewNumGoRoutinesThrottler(numSignatureGoRoutinesThrottler) if err != nil { return nil, err } - err = fct.GenerateSubrounds() + subroundsHandlerArgs := &proxy.SubroundsHandlerArgs{ + Chronology: cc.chronology, + ConsensusCoreHandler: consensusDataContainer, + ConsensusState: consensusState, + Worker: cc.worker, + 
SignatureThrottler: signatureThrottler, + AppStatusHandler: ccf.statusCoreComponents.AppStatusHandler(), + OutportHandler: ccf.statusComponents.OutportHandler(), + SentSignatureTracker: ccf.processComponents.SentSignaturesTracker(), + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), + ChainID: []byte(ccf.coreComponents.ChainID()), + CurrentPid: ccf.networkComponents.NetworkMessenger().ID(), + } + + subroundsHandler, err := proxy.NewSubroundsHandler(subroundsHandlerArgs) if err != nil { return nil, err } - cc.chronology.StartRounds() + err = subroundsHandler.Start(epoch) + if err != nil { + return nil, err + } err = ccf.addCloserInstances(cc.chronology, cc.bootstrapper, cc.worker, ccf.coreComponents.SyncTimer()) if err != nil { @@ -434,6 +457,8 @@ func (ccf *consensusComponentsFactory) createShardBootstrapper() (process.Bootst EpochNotifier: ccf.coreComponents.EpochNotifier(), ProcessedMiniBlocksTracker: ccf.processComponents.ProcessedMiniBlocksTracker(), AppStatusHandler: ccf.statusCoreComponents.AppStatusHandler(), + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), + ProofsPool: ccf.dataComponents.Datapool().Proofs(), } argsShardStorageBootstrapper := storageBootstrap.ArgsShardStorageBootstrapper{ @@ -488,6 +513,7 @@ func (ccf *consensusComponentsFactory) createShardBootstrapper() (process.Bootst ScheduledTxsExecutionHandler: ccf.processComponents.ScheduledTxsExecutionHandler(), ProcessWaitTime: time.Duration(ccf.config.GeneralSettings.SyncProcessTimeInMillis) * time.Millisecond, RepopulateTokensSupplies: ccf.flagsConfig.RepopulateTokensSupplies, + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), } argsShardBootstrapper := sync.ArgShardBootstrapper{ @@ -567,6 +593,8 @@ func (ccf *consensusComponentsFactory) createMetaChainBootstrapper() (process.Bo EpochNotifier: ccf.coreComponents.EpochNotifier(), ProcessedMiniBlocksTracker: ccf.processComponents.ProcessedMiniBlocksTracker(), AppStatusHandler: ccf.statusCoreComponents.AppStatusHandler(), + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), + ProofsPool: ccf.dataComponents.Datapool().Proofs(), } argsMetaStorageBootstrapper := storageBootstrap.ArgsMetaStorageBootstrapper{ @@ -618,6 +646,7 @@ func (ccf *consensusComponentsFactory) createMetaChainBootstrapper() (process.Bo ScheduledTxsExecutionHandler: ccf.processComponents.ScheduledTxsExecutionHandler(), ProcessWaitTime: time.Duration(ccf.config.GeneralSettings.SyncProcessTimeInMillis) * time.Millisecond, RepopulateTokensSupplies: ccf.flagsConfig.RepopulateTokensSupplies, + EnableEpochsHandler: ccf.coreComponents.EnableEpochsHandler(), } argsMetaBootstrapper := sync.ArgMetaBootstrapper{ @@ -752,12 +781,17 @@ func checkArgs(args ConsensusComponentsFactoryArgs) error { return nil } -func getConsensusGroupSize(nodesConfig sharding.GenesisNodesSetupHandler, shardCoordinator sharding.Coordinator) (uint32, error) { +func getConsensusGroupSize(nodesConfig sharding.GenesisNodesSetupHandler, shardCoordinator sharding.Coordinator, nodesCoordinator nodesCoord.NodesCoordinator, epoch uint32) (int, error) { + consensusGroupSize := nodesCoordinator.ConsensusGroupSizeForShardAndEpoch(shardCoordinator.SelfId(), epoch) + if consensusGroupSize > 0 { + return consensusGroupSize, nil + } + if shardCoordinator.SelfId() == core.MetachainShardId { - return nodesConfig.GetMetaConsensusGroupSize(), nil + return int(nodesConfig.GetMetaConsensusGroupSize()), nil } if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { - return 
nodesConfig.GetShardConsensusGroupSize(), nil + return int(nodesConfig.GetShardConsensusGroupSize()), nil } return 0, sharding.ErrShardIdOutOfRange diff --git a/factory/consensus/consensusComponentsHandler.go b/factory/consensus/consensusComponentsHandler.go index 4e7779ab367..7fbaeb49381 100644 --- a/factory/consensus/consensusComponentsHandler.go +++ b/factory/consensus/consensusComponentsHandler.go @@ -101,18 +101,6 @@ func (mcc *managedConsensusComponents) BroadcastMessenger() consensus.BroadcastM return mcc.consensusComponents.broadcastMessenger } -// ConsensusGroupSize returns the consensus group size -func (mcc *managedConsensusComponents) ConsensusGroupSize() (int, error) { - mcc.mutConsensusComponents.RLock() - defer mcc.mutConsensusComponents.RUnlock() - - if mcc.consensusComponents == nil { - return 0, errors.ErrNilConsensusComponentsHolder - } - - return mcc.consensusComponents.consensusGroupSize, nil -} - // CheckSubcomponents verifies all subcomponents func (mcc *managedConsensusComponents) CheckSubcomponents() error { mcc.mutConsensusComponents.RLock() diff --git a/factory/consensus/consensusComponentsHandler_test.go b/factory/consensus/consensusComponentsHandler_test.go index c0a89f8a08e..ded44d2a837 100644 --- a/factory/consensus/consensusComponentsHandler_test.go +++ b/factory/consensus/consensusComponentsHandler_test.go @@ -74,24 +74,6 @@ func TestManagedConsensusComponents_Create(t *testing.T) { }) } -func TestManagedConsensusComponents_ConsensusGroupSize(t *testing.T) { - t.Parallel() - - consensusComponentsFactory, _ := consensusComp.NewConsensusComponentsFactory(createMockConsensusComponentsFactoryArgs()) - managedConsensusComponents, _ := consensusComp.NewManagedConsensusComponents(consensusComponentsFactory) - require.NotNil(t, managedConsensusComponents) - - size, err := managedConsensusComponents.ConsensusGroupSize() - require.Equal(t, errorsMx.ErrNilConsensusComponentsHolder, err) - require.Zero(t, size) - - err = managedConsensusComponents.Create() - require.NoError(t, err) - size, err = managedConsensusComponents.ConsensusGroupSize() - require.NoError(t, err) - require.Equal(t, 2, size) -} - func TestManagedConsensusComponents_CheckSubcomponents(t *testing.T) { t.Parallel() diff --git a/factory/consensus/consensusComponents_test.go b/factory/consensus/consensusComponents_test.go index c907bc84951..161c49e777f 100644 --- a/factory/consensus/consensusComponents_test.go +++ b/factory/consensus/consensusComponents_test.go @@ -9,6 +9,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" retriever "github.com/multiversx/mx-chain-go/dataRetriever" @@ -21,9 +23,11 @@ import ( "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + dataRetrieverMocks 
"github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" @@ -38,7 +42,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storageManager" "github.com/multiversx/mx-chain-go/update" - "github.com/stretchr/testify/require" ) func createMockConsensusComponentsFactoryArgs() consensusComp.ConsensusComponentsFactoryArgs { @@ -91,14 +94,17 @@ func createMockConsensusComponentsFactoryArgs() consensusComp.ConsensusComponent DataComponents: &testsMocks.DataComponentsStub{ DataPool: &dataRetriever.PoolsHolderStub{ MiniBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{} + return &cache.CacherStub{} }, TrieNodesCalled: func() storage.Cacher { - return &testscommon.CacherStub{} + return &cache.CacherStub{} }, HeadersCalled: func() retriever.HeadersPool { return &testsMocks.HeadersCacherStub{} }, + ProofsCalled: func() retriever.ProofsPool { + return &dataRetrieverMocks.ProofsPoolMock{} + }, }, BlockChain: &testscommon.ChainHandlerStub{ GetGenesisHeaderHashCalled: func() []byte { @@ -114,7 +120,7 @@ func createMockConsensusComponentsFactoryArgs() consensusComp.ConsensusComponent ProcessComponents: &testsMocks.ProcessComponentsStub{ EpochTrigger: &testsMocks.EpochStartTriggerStub{}, EpochNotifier: &testsMocks.EpochStartNotifierStub{}, - NodesCoord: &shardingMocks.NodesCoordinatorStub{}, + NodesCoord: &shardingMocks.NodesCoordinatorMock{}, NodeRedundancyHandlerInternal: &testsMocks.RedundancyHandlerStub{}, HardforkTriggerField: &testscommon.HardforkTriggerStub{}, ReqHandler: &testscommon.RequestHandlerStub{}, @@ -137,7 +143,7 @@ func createMockConsensusComponentsFactoryArgs() consensusComp.ConsensusComponent CurrentEpochProviderInternal: &testsMocks.CurrentNetworkEpochProviderStub{}, HistoryRepositoryInternal: &dblookupext.HistoryRepositoryStub{}, IntContainer: &testscommon.InterceptorsContainerStub{}, - HeaderSigVerif: &testsMocks.HeaderSigVerifierStub{}, + HeaderSigVerif: &consensusMocks.HeaderSigVerifierMock{}, HeaderIntegrVerif: &mock.HeaderIntegrityVerifierStub{}, FallbackHdrValidator: &testscommon.FallBackHeaderValidatorStub{}, SentSignaturesTrackerInternal: &testscommon.SentSignatureTrackerStub{}, @@ -316,7 +322,7 @@ func TestNewConsensusComponentsFactory(t *testing.T) { args := createMockConsensusComponentsFactoryArgs() args.ProcessComponents = &testsMocks.ProcessComponentsStub{ - NodesCoord: &shardingMocks.NodesCoordinatorStub{}, + NodesCoord: &shardingMocks.NodesCoordinatorMock{}, ShardCoord: nil, } ccf, err := consensusComp.NewConsensusComponentsFactory(args) @@ -329,7 +335,7 @@ func TestNewConsensusComponentsFactory(t *testing.T) { args := createMockConsensusComponentsFactoryArgs() args.ProcessComponents = &testsMocks.ProcessComponentsStub{ - NodesCoord: &shardingMocks.NodesCoordinatorStub{}, + NodesCoord: &shardingMocks.NodesCoordinatorMock{}, ShardCoord: &testscommon.ShardsCoordinatorMock{}, RoundHandlerField: nil, } @@ -343,7 +349,7 @@ func TestNewConsensusComponentsFactory(t *testing.T) { args := createMockConsensusComponentsFactoryArgs() args.ProcessComponents = &testsMocks.ProcessComponentsStub{ - NodesCoord: &shardingMocks.NodesCoordinatorStub{}, + NodesCoord: 
&shardingMocks.NodesCoordinatorMock{}, ShardCoord: &testscommon.ShardsCoordinatorMock{}, RoundHandlerField: &testscommon.RoundHandlerMock{}, HardforkTriggerField: nil, @@ -498,7 +504,7 @@ func TestConsensusComponentsFactory_Create(t *testing.T) { cnt := 0 processCompStub.ShardCoordinatorCalled = func() sharding.Coordinator { cnt++ - if cnt > 2 { + if cnt > 1 { return nil // createBootstrapper fails } return testscommon.NewMultiShardsCoordinatorMock(2) @@ -520,7 +526,7 @@ func TestConsensusComponentsFactory_Create(t *testing.T) { shardC := testscommon.NewMultiShardsCoordinatorMock(2) processCompStub.ShardCoordinatorCalled = func() sharding.Coordinator { cnt++ - if cnt > 2 { + if cnt > 1 { shardC.SelfIDCalled = func() uint32 { return shardC.NoShards + 1 // createBootstrapper returns ErrShardIdOutOfRange } @@ -535,28 +541,6 @@ func TestConsensusComponentsFactory_Create(t *testing.T) { require.Equal(t, sharding.ErrShardIdOutOfRange, err) require.Nil(t, cc) }) - t.Run("createShardBootstrapper fails due to NewShardStorageBootstrapper failure should error", func(t *testing.T) { - t.Parallel() - - args := createMockConsensusComponentsFactoryArgs() - processCompStub, ok := args.ProcessComponents.(*testsMocks.ProcessComponentsStub) - require.True(t, ok) - cnt := 0 - processCompStub.ShardCoordinatorCalled = func() sharding.Coordinator { - cnt++ - if cnt > 3 { - return nil // NewShardStorageBootstrapper fails - } - return testscommon.NewMultiShardsCoordinatorMock(2) - } - ccf, _ := consensusComp.NewConsensusComponentsFactory(args) - require.NotNil(t, ccf) - - cc, err := ccf.Create() - require.Error(t, err) - require.True(t, strings.Contains(err.Error(), "shard coordinator")) - require.Nil(t, cc) - }) t.Run("createUserAccountsSyncer fails due to missing UserAccountTrie should error", func(t *testing.T) { t.Parallel() @@ -584,30 +568,6 @@ func TestConsensusComponentsFactory_Create(t *testing.T) { require.True(t, strings.Contains(err.Error(), "value is not positive")) require.Nil(t, cc) }) - t.Run("createMetaChainBootstrapper fails due to NewMetaStorageBootstrapper failure should error", func(t *testing.T) { - t.Parallel() - - args := createMockConsensusComponentsFactoryArgs() - processCompStub, ok := args.ProcessComponents.(*testsMocks.ProcessComponentsStub) - require.True(t, ok) - cnt := 0 - processCompStub.ShardCoordinatorCalled = func() sharding.Coordinator { - cnt++ - if cnt > 3 { - return nil // NewShardStorageBootstrapper fails - } - shardC := testscommon.NewMultiShardsCoordinatorMock(2) - shardC.CurrentShard = core.MetachainShardId - return shardC - } - ccf, _ := consensusComp.NewConsensusComponentsFactory(args) - require.NotNil(t, ccf) - - cc, err := ccf.Create() - require.Error(t, err) - require.True(t, strings.Contains(err.Error(), "shard coordinator")) - require.Nil(t, cc) - }) t.Run("createUserAccountsSyncer fails due to missing UserAccountTrie should error", func(t *testing.T) { t.Parallel() @@ -698,27 +658,6 @@ func TestConsensusComponentsFactory_Create(t *testing.T) { require.Equal(t, expectedErr, err) require.Nil(t, cc) }) - t.Run("createConsensusState fails due to nil nodes coordinator should error", func(t *testing.T) { - t.Parallel() - - args := createMockConsensusComponentsFactoryArgs() - processCompStub, ok := args.ProcessComponents.(*testsMocks.ProcessComponentsStub) - require.True(t, ok) - cnt := 0 - processCompStub.NodesCoordinatorCalled = func() nodesCoordinator.NodesCoordinator { - cnt++ - if cnt > 2 { - return nil - } - return &shardingMocks.NodesCoordinatorStub{} - } - ccf, _ 
:= consensusComp.NewConsensusComponentsFactory(args) - require.NotNil(t, ccf) - - cc, err := ccf.Create() - require.Equal(t, errorsMx.ErrNilNodesCoordinator, err) - require.Nil(t, cc) - }) t.Run("createConsensusState fails due to GetConsensusWhitelistedNodes failure should error", func(t *testing.T) { t.Parallel() @@ -726,7 +665,7 @@ func TestConsensusComponentsFactory_Create(t *testing.T) { processCompStub, ok := args.ProcessComponents.(*testsMocks.ProcessComponentsStub) require.True(t, ok) processCompStub.NodesCoordinatorCalled = func() nodesCoordinator.NodesCoordinator { - return &shardingMocks.NodesCoordinatorStub{ + return &shardingMocks.NodesCoordinatorMock{ GetConsensusWhitelistedNodesCalled: func(epoch uint32) (map[string]struct{}, error) { return nil, expectedErr }, @@ -811,7 +750,7 @@ func TestConsensusComponentsFactory_Create(t *testing.T) { cnt := 0 processCompStub.ShardCoordinatorCalled = func() sharding.Coordinator { cnt++ - if cnt > 9 { + if cnt >= 10 { return nil // createConsensusTopic fails } return testscommon.NewMultiShardsCoordinatorMock(2) @@ -832,7 +771,7 @@ func TestConsensusComponentsFactory_Create(t *testing.T) { cnt := 0 netwCompStub.MessengerCalled = func() p2p.Messenger { cnt++ - if cnt > 3 { + if cnt > 4 { return nil } return &p2pmocks.MessengerStub{} @@ -902,28 +841,6 @@ func TestConsensusComponentsFactory_Create(t *testing.T) { require.True(t, strings.Contains(err.Error(), "signing handler")) require.Nil(t, cc) }) - t.Run("GetSubroundsFactory failure should error", func(t *testing.T) { - t.Parallel() - - args := createMockConsensusComponentsFactoryArgs() - statusCoreCompStub, ok := args.StatusCoreComponents.(*factoryMocks.StatusCoreComponentsStub) - require.True(t, ok) - cnt := 0 - statusCoreCompStub.AppStatusHandlerCalled = func() core.AppStatusHandler { - cnt++ - if cnt > 4 { - return nil - } - return &statusHandler.AppStatusHandlerStub{} - } - ccf, _ := consensusComp.NewConsensusComponentsFactory(args) - require.NotNil(t, ccf) - - cc, err := ccf.Create() - require.Error(t, err) - require.True(t, strings.Contains(err.Error(), "AppStatusHandler")) - require.Nil(t, cc) - }) t.Run("addCloserInstances failure should error", func(t *testing.T) { t.Parallel() diff --git a/factory/core/coreComponents.go b/factory/core/coreComponents.go index 1995aa43c72..1322a2c99d7 100644 --- a/factory/core/coreComponents.go +++ b/factory/core/coreComponents.go @@ -19,10 +19,15 @@ import ( hasherFactory "github.com/multiversx/mx-chain-core-go/hashing/factory" "github.com/multiversx/mx-chain-core-go/marshal" marshalizerFactory "github.com/multiversx/mx-chain-core-go/marshal/factory" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/chainparametersnotifier" "github.com/multiversx/mx-chain-go/common/enablers" commonFactory "github.com/multiversx/mx-chain-go/common/factory" + "github.com/multiversx/mx-chain-go/common/fieldsChecker" "github.com/multiversx/mx-chain-go/common/forking" + "github.com/multiversx/mx-chain-go/common/graceperiod" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/round" @@ -38,7 +43,6 @@ import ( 
"github.com/multiversx/mx-chain-go/statusHandler" "github.com/multiversx/mx-chain-go/storage" storageFactory "github.com/multiversx/mx-chain-go/storage/factory" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("factory") @@ -52,7 +56,7 @@ type CoreComponentsFactoryArgs struct { RatingsConfig config.RatingsConfig EconomicsConfig config.EconomicsConfig ImportDbConfig config.ImportDbConfig - NodesFilename string + NodesConfig config.NodesConfig WorkingDirectory string ChanStopNodeProcess chan endProcess.ArgEndProcess } @@ -66,7 +70,7 @@ type coreComponentsFactory struct { ratingsConfig config.RatingsConfig economicsConfig config.EconomicsConfig importDbConfig config.ImportDbConfig - nodesFilename string + nodesSetupConfig config.NodesConfig workingDir string chanStopNodeProcess chan endProcess.ArgEndProcess } @@ -98,6 +102,7 @@ type coreComponents struct { minTransactionVersion uint32 epochNotifier process.EpochNotifier roundNotifier process.RoundNotifier + chainParametersSubscriber process.ChainParametersSubscriber enableRoundsHandler process.EnableRoundsHandler epochStartNotifierWithConfirm factory.EpochStartNotifierWithConfirm chanStopNodeProcess chan endProcess.ArgEndProcess @@ -107,6 +112,9 @@ type coreComponents struct { processStatusHandler common.ProcessStatusHandler hardforkTriggerPubKey []byte enableEpochsHandler common.EnableEpochsHandler + chainParametersHandler process.ChainParametersHandler + fieldsSizeChecker common.FieldsSizeChecker + epochChangeGracePeriodHandler common.EpochChangeGracePeriodHandler } // NewCoreComponentsFactory initializes the factory which is responsible to creating core components @@ -121,7 +129,7 @@ func NewCoreComponentsFactory(args CoreComponentsFactoryArgs) (*coreComponentsFa economicsConfig: args.EconomicsConfig, workingDir: args.WorkingDirectory, chanStopNodeProcess: args.ChanStopNodeProcess, - nodesFilename: args.NodesFilename, + nodesSetupConfig: args.NodesConfig, }, nil } @@ -164,6 +172,11 @@ func (ccf *coreComponentsFactory) Create() (*coreComponents, error) { return nil, fmt.Errorf("%w for AddressPubkeyConverter", err) } + epochChangeGracePeriodHandler, err := graceperiod.NewEpochChangeGracePeriod(ccf.config.GeneralSettings.EpochChangeGracePeriodByEpoch) + if err != nil { + return nil, fmt.Errorf("%w for epochChangeGracePeriod", err) + } + pathHandler, err := storageFactory.CreatePathManager( storageFactory.ArgCreatePathManager{ WorkingDir: ccf.workingDir, @@ -178,8 +191,23 @@ func (ccf *coreComponentsFactory) Create() (*coreComponents, error) { syncer.StartSyncingTime() log.Debug("NTP average clock offset", "value", syncer.ClockOffset()) + epochNotifier := forking.NewGenericEpochNotifier() + epochStartHandlerWithConfirm := notifier.NewEpochStartSubscriptionHandler() + + chainParametersNotifier := chainparametersnotifier.NewChainParametersNotifier() + argsChainParametersHandler := sharding.ArgsChainParametersHolder{ + EpochStartEventNotifier: epochStartHandlerWithConfirm, + ChainParameters: ccf.config.GeneralSettings.ChainParametersByEpoch, + ChainParametersNotifier: chainParametersNotifier, + } + chainParametersHandler, err := sharding.NewChainParametersHolder(argsChainParametersHandler) + if err != nil { + return nil, err + } + genesisNodesConfig, err := sharding.NewNodesSetup( - ccf.nodesFilename, + ccf.nodesSetupConfig, + chainParametersHandler, addressPubkeyConverter, validatorPubkeyConverter, 
ccf.config.GeneralSettings.GenesisMaxNumberOfShards, @@ -209,8 +237,6 @@ func (ccf *coreComponentsFactory) Create() (*coreComponents, error) { "formatted", startTime.Format("Mon Jan 2 15:04:05 MST 2006"), "seconds", startTime.Unix()) - log.Debug("config", "file", ccf.nodesFilename) - genesisTime := time.Unix(genesisNodesConfig.StartTime, 0) roundHandler, err := round.NewRound( genesisTime, @@ -236,7 +262,6 @@ func (ccf *coreComponentsFactory) Create() (*coreComponents, error) { return nil, err } - epochNotifier := forking.NewGenericEpochNotifier() enableEpochsHandler, err := enablers.NewEnableEpochsHandler(ccf.epochConfig.EnableEpochs, epochNotifier) if err != nil { return nil, err @@ -276,12 +301,10 @@ func (ccf *coreComponentsFactory) Create() (*coreComponents, error) { log.Trace("creating ratings data") ratingDataArgs := rating.RatingsDataArg{ - Config: ccf.ratingsConfig, - ShardConsensusSize: genesisNodesConfig.ConsensusGroupSize, - MetaConsensusSize: genesisNodesConfig.MetaChainConsensusGroupSize, - ShardMinNodes: genesisNodesConfig.MinNodesPerShard, - MetaMinNodes: genesisNodesConfig.MetaChainMinNodes, - RoundDurationMiliseconds: genesisNodesConfig.RoundDuration, + Config: ccf.ratingsConfig, + ChainParametersHolder: chainParametersHandler, + RoundDurationMilliseconds: genesisNodesConfig.RoundDuration, + EpochNotifier: epochNotifier, } ratingsData, err := rating.NewRatingsData(ratingDataArgs) if err != nil { @@ -294,10 +317,6 @@ func (ccf *coreComponentsFactory) Create() (*coreComponents, error) { } argsNodesShuffler := &nodesCoordinator.NodesShufflerArgs{ - NodesShard: genesisNodesConfig.MinNumberOfShardNodes(), - NodesMeta: genesisNodesConfig.MinNumberOfMetaNodes(), - Hysteresis: genesisNodesConfig.GetHysteresis(), - Adaptivity: genesisNodesConfig.GetAdaptivity(), ShuffleBetweenShards: true, MaxNodesEnableConfig: ccf.epochConfig.EnableEpochs.MaxNodesChangeEnableEpoch, EnableEpochsHandler: enableEpochsHandler, @@ -323,6 +342,11 @@ func (ccf *coreComponentsFactory) Create() (*coreComponents, error) { return nil, err } + fieldsSizeChecker, err := fieldsChecker.NewFieldsSizeChecker(chainParametersHandler, hasher) + if err != nil { + return nil, err + } + return &coreComponents{ hasher: hasher, txSignHasher: txSignHasher, @@ -349,8 +373,9 @@ func (ccf *coreComponentsFactory) Create() (*coreComponents, error) { minTransactionVersion: ccf.config.GeneralSettings.MinTransactionVersion, epochNotifier: epochNotifier, roundNotifier: roundNotifier, + chainParametersSubscriber: chainParametersNotifier, enableRoundsHandler: enableRoundsHandler, - epochStartNotifierWithConfirm: notifier.NewEpochStartSubscriptionHandler(), + epochStartNotifierWithConfirm: epochStartHandlerWithConfirm, chanStopNodeProcess: ccf.chanStopNodeProcess, encodedAddressLen: encodedAddressLen, nodeTypeProvider: nodeTypeProvider, @@ -358,6 +383,9 @@ func (ccf *coreComponentsFactory) Create() (*coreComponents, error) { processStatusHandler: statusHandler.NewProcessStatusHandler(), hardforkTriggerPubKey: pubKeyBytes, enableEpochsHandler: enableEpochsHandler, + chainParametersHandler: chainParametersHandler, + fieldsSizeChecker: fieldsSizeChecker, + epochChangeGracePeriodHandler: epochChangeGracePeriodHandler, }, nil } diff --git a/factory/core/coreComponentsHandler.go b/factory/core/coreComponentsHandler.go index b10c378023e..13b1735ad36 100644 --- a/factory/core/coreComponentsHandler.go +++ b/factory/core/coreComponentsHandler.go @@ -11,6 +11,7 @@ import ( 
"github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/errors" @@ -149,6 +150,12 @@ func (mcc *managedCoreComponents) CheckSubcomponents() error { if check.IfNil(mcc.enableEpochsHandler) { return errors.ErrNilEnableEpochsHandler } + if check.IfNil(mcc.chainParametersHandler) { + return errors.ErrNilChainParametersHandler + } + if check.IfNil(mcc.fieldsSizeChecker) { + return errors.ErrNilFieldsSizeChecker + } if len(mcc.chainID) == 0 { return errors.ErrInvalidChainID } @@ -485,6 +492,18 @@ func (mcc *managedCoreComponents) RoundNotifier() process.RoundNotifier { return mcc.coreComponents.roundNotifier } +// ChainParametersSubscriber returns the chain parameters subscriber +func (mcc *managedCoreComponents) ChainParametersSubscriber() process.ChainParametersSubscriber { + mcc.mutCoreComponents.RLock() + defer mcc.mutCoreComponents.RUnlock() + + if mcc.coreComponents == nil { + return nil + } + + return mcc.coreComponents.chainParametersSubscriber +} + // EnableRoundsHandler returns the rounds activation handler func (mcc *managedCoreComponents) EnableRoundsHandler() process.EnableRoundsHandler { mcc.mutCoreComponents.RLock() @@ -581,6 +600,42 @@ func (mcc *managedCoreComponents) EnableEpochsHandler() common.EnableEpochsHandl return mcc.coreComponents.enableEpochsHandler } +// ChainParametersHandler returns the chain parameters handler +func (mcc *managedCoreComponents) ChainParametersHandler() process.ChainParametersHandler { + mcc.mutCoreComponents.RLock() + defer mcc.mutCoreComponents.RUnlock() + + if mcc.coreComponents == nil { + return nil + } + + return mcc.coreComponents.chainParametersHandler +} + +// FieldsSizeChecker returns the fields size checker component +func (mcc *managedCoreComponents) FieldsSizeChecker() common.FieldsSizeChecker { + mcc.mutCoreComponents.RLock() + defer mcc.mutCoreComponents.RUnlock() + + if mcc.coreComponents == nil { + return nil + } + + return mcc.coreComponents.fieldsSizeChecker +} + +// EpochChangeGracePeriodHandler returns the epoch change grace period handler component +func (mcc *managedCoreComponents) EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler { + mcc.mutCoreComponents.RLock() + defer mcc.mutCoreComponents.RUnlock() + + if mcc.coreComponents == nil { + return nil + } + + return mcc.coreComponents.epochChangeGracePeriodHandler +} + // IsInterfaceNil returns true if there is no value under the interface func (mcc *managedCoreComponents) IsInterfaceNil() bool { return mcc == nil diff --git a/factory/heartbeat/heartbeatV2Components_test.go b/factory/heartbeat/heartbeatV2Components_test.go index 9a0eb3b14e3..f605bc67b9c 100644 --- a/factory/heartbeat/heartbeatV2Components_test.go +++ b/factory/heartbeat/heartbeatV2Components_test.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" errorsMx "github.com/multiversx/mx-chain-go/errors" @@ -14,6 +16,7 @@ import ( 
"github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks" + "github.com/multiversx/mx-chain-go/testscommon/cache" componentsMock "github.com/multiversx/mx-chain-go/testscommon/components" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" @@ -23,7 +26,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) func createMockHeartbeatV2ComponentsFactoryArgs() heartbeatComp.ArgHeartbeatV2ComponentsFactory { @@ -54,10 +56,10 @@ func createMockHeartbeatV2ComponentsFactoryArgs() heartbeatComp.ArgHeartbeatV2Co DataComponents: &testsMocks.DataComponentsStub{ DataPool: &dataRetriever.PoolsHolderStub{ PeerAuthenticationsCalled: func() storage.Cacher { - return &testscommon.CacherStub{} + return &cache.CacherStub{} }, HeartbeatsCalled: func() storage.Cacher { - return &testscommon.CacherStub{} + return &cache.CacherStub{} }, }, BlockChain: &testscommon.ChainHandlerStub{}, diff --git a/factory/interface.go b/factory/interface.go index f30fce3784b..cf526c2e6f2 100644 --- a/factory/interface.go +++ b/factory/interface.go @@ -14,6 +14,8 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" @@ -37,7 +39,6 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/update" "github.com/multiversx/mx-chain-go/vm" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) // EpochStartNotifier defines which actions should be done for handling new epoch's events @@ -66,7 +67,7 @@ type P2PAntifloodHandler interface { SetDebugger(debugger process.AntifloodDebugger) error SetPeerValidatorMapper(validatorMapper process.PeerValidatorMapper) error SetTopicsForAll(topics ...string) - ApplyConsensusSize(size int) + SetConsensusSizeNotifier(chainParametersNotifier process.ChainParametersSubscriber, shardID uint32) BlacklistPeer(peer core.PeerID, reason string, duration time.Duration) IsOriginatorEligibleForTopic(pid core.PeerID, topic string) error Close() error @@ -120,6 +121,7 @@ type CoreComponentsHolder interface { GenesisNodesSetup() sharding.GenesisNodesSetupHandler NodesShuffler() nodesCoordinator.NodesShuffler EpochNotifier() process.EpochNotifier + ChainParametersSubscriber() process.ChainParametersSubscriber EnableRoundsHandler() process.EnableRoundsHandler RoundNotifier() process.RoundNotifier EpochStartNotifierWithConfirm() EpochStartNotifierWithConfirm @@ -134,6 +136,9 @@ type CoreComponentsHolder interface { ProcessStatusHandler() common.ProcessStatusHandler HardforkTriggerPubKey() []byte EnableEpochsHandler() common.EnableEpochsHandler + 
ChainParametersHandler() process.ChainParametersHandler + FieldsSizeChecker() common.FieldsSizeChecker + EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler IsInterfaceNil() bool } @@ -336,6 +341,7 @@ type StateComponentsHolder interface { TriesContainer() common.TriesHolder TrieStorageManagers() map[string]common.StorageManager MissingTrieNodesNotifier() common.MissingTrieNodesNotifier + TrieLeavesRetriever() common.TrieLeavesRetriever Close() error IsInterfaceNil() bool } @@ -384,10 +390,14 @@ type ConsensusWorker interface { AddReceivedMessageCall(messageType consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) // AddReceivedHeaderHandler adds a new handler function for a received header AddReceivedHeaderHandler(handler func(data.HeaderHandler)) + // RemoveAllReceivedHeaderHandlers removes all the functions handlers + RemoveAllReceivedHeaderHandlers() + // AddReceivedProofHandler adds a new handler function for a received proof + AddReceivedProofHandler(handler func(proofHandler consensus.ProofHandler)) // RemoveAllReceivedMessagesCalls removes all the functions handlers RemoveAllReceivedMessagesCalls() // ProcessReceivedMessage method redirects the received message to the channel which should handle it - ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error + ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) // Extend does an extension for the subround with subroundId Extend(subroundId int) // GetConsensusStateChangedChannel gets the channel for the consensusStateChanged @@ -396,10 +406,16 @@ type ConsensusWorker interface { ExecuteStoredMessages() // DisplayStatistics method displays statistics of worker at the end of the round DisplayStatistics() - // ResetConsensusMessages resets at the start of each round all the previous consensus messages received + // ResetConsensusMessages resets at the start of each round all the previous consensus messages received and equivalent messages, keeping the provided proofs ResetConsensusMessages() + // ResetConsensusRoundState resets the state of the consensus round + ResetConsensusRoundState() + // ResetInvalidSignersCache resets the invalid signers cache + ResetInvalidSignersCache() // ReceivedHeader method is a wired method through which worker will receive headers from network ReceivedHeader(headerHandler data.HeaderHandler, headerHash []byte) + // ReceivedProof will handle a received proof in consensus worker + ReceivedProof(proofHandler consensus.ProofHandler) // IsInterfaceNil returns true if there is no value under the interface IsInterfaceNil() bool } @@ -422,7 +438,6 @@ type ConsensusComponentsHolder interface { Chronology() consensus.ChronologyHandler ConsensusWorker() ConsensusWorker BroadcastMessenger() consensus.BroadcastMessenger - ConsensusGroupSize() (int, error) Bootstrapper() process.Bootstrapper IsInterfaceNil() bool } diff --git a/factory/mock/coreComponentsMock.go b/factory/mock/coreComponentsMock.go index 0393f44c4a1..6d38ca5208d 100644 --- a/factory/mock/coreComponentsMock.go +++ b/factory/mock/coreComponentsMock.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" 
"github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/factory" @@ -21,41 +22,45 @@ import ( // CoreComponentsMock - type CoreComponentsMock struct { - IntMarsh marshal.Marshalizer - TxMarsh marshal.Marshalizer - VmMarsh marshal.Marshalizer - Hash hashing.Hasher - TxSignHasherField hashing.Hasher - UInt64ByteSliceConv typeConverters.Uint64ByteSliceConverter - AddrPubKeyConv core.PubkeyConverter - ValPubKeyConv core.PubkeyConverter - PathHdl storage.PathManagerHandler - WatchdogTimer core.WatchdogTimer - AlarmSch core.TimersScheduler - NtpSyncTimer ntp.SyncTimer - GenesisBlockTime time.Time - ChainIdCalled func() string - MinTransactionVersionCalled func() uint32 - mutIntMarshalizer sync.RWMutex - RoundHandlerField consensus.RoundHandler - EconomicsHandler process.EconomicsDataHandler - APIEconomicsHandler process.EconomicsDataHandler - RatingsConfig process.RatingsInfoHandler - RatingHandler sharding.PeerAccountListAndRatingHandler - NodesConfig sharding.GenesisNodesSetupHandler - Shuffler nodesCoordinator.NodesShuffler - EpochChangeNotifier process.EpochNotifier - RoundChangeNotifier process.RoundNotifier - EnableRoundsHandlerField process.EnableRoundsHandler - EpochNotifierWithConfirm factory.EpochStartNotifierWithConfirm - TxVersionCheckHandler process.TxVersionCheckerHandler - ChanStopProcess chan endProcess.ArgEndProcess - StartTime time.Time - NodeTypeProviderField core.NodeTypeProviderHandler - WasmVMChangeLockerInternal common.Locker - ProcessStatusHandlerInternal common.ProcessStatusHandler - HardforkTriggerPubKeyField []byte - EnableEpochsHandlerField common.EnableEpochsHandler + IntMarsh marshal.Marshalizer + TxMarsh marshal.Marshalizer + VmMarsh marshal.Marshalizer + Hash hashing.Hasher + TxSignHasherField hashing.Hasher + UInt64ByteSliceConv typeConverters.Uint64ByteSliceConverter + AddrPubKeyConv core.PubkeyConverter + ValPubKeyConv core.PubkeyConverter + PathHdl storage.PathManagerHandler + WatchdogTimer core.WatchdogTimer + AlarmSch core.TimersScheduler + NtpSyncTimer ntp.SyncTimer + GenesisBlockTime time.Time + ChainIdCalled func() string + MinTransactionVersionCalled func() uint32 + mutIntMarshalizer sync.RWMutex + RoundHandlerField consensus.RoundHandler + EconomicsHandler process.EconomicsDataHandler + APIEconomicsHandler process.EconomicsDataHandler + RatingsConfig process.RatingsInfoHandler + RatingHandler sharding.PeerAccountListAndRatingHandler + NodesConfig sharding.GenesisNodesSetupHandler + Shuffler nodesCoordinator.NodesShuffler + EpochChangeNotifier process.EpochNotifier + RoundChangeNotifier process.RoundNotifier + EnableRoundsHandlerField process.EnableRoundsHandler + EpochNotifierWithConfirm factory.EpochStartNotifierWithConfirm + TxVersionCheckHandler process.TxVersionCheckerHandler + ChanStopProcess chan endProcess.ArgEndProcess + StartTime time.Time + NodeTypeProviderField core.NodeTypeProviderHandler + WasmVMChangeLockerInternal common.Locker + ProcessStatusHandlerInternal common.ProcessStatusHandler + HardforkTriggerPubKeyField []byte + EnableEpochsHandlerField common.EnableEpochsHandler + ChainParametersHandlerField process.ChainParametersHandler + ChainParametersSubscriberField process.ChainParametersSubscriber + FieldsSizeCheckerField common.FieldsSizeChecker + EpochChangeGracePeriodHandlerField common.EpochChangeGracePeriodHandler } // InternalMarshalizer - @@ -246,6 +251,26 @@ func (ccm *CoreComponentsMock) EnableEpochsHandler() common.EnableEpochsHandler return 
ccm.EnableEpochsHandlerField } +// ChainParametersHandler - +func (ccm *CoreComponentsMock) ChainParametersHandler() process.ChainParametersHandler { + return ccm.ChainParametersHandlerField +} + +// ChainParametersSubscriber - +func (ccm *CoreComponentsMock) ChainParametersSubscriber() process.ChainParametersSubscriber { + return ccm.ChainParametersSubscriberField +} + +// FieldsSizeChecker - +func (ccm *CoreComponentsMock) FieldsSizeChecker() common.FieldsSizeChecker { + return ccm.FieldsSizeCheckerField +} + +// EpochChangeGracePeriodHandler - +func (ccm *CoreComponentsMock) EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler { + return ccm.EpochChangeGracePeriodHandlerField +} + // IsInterfaceNil - func (ccm *CoreComponentsMock) IsInterfaceNil() bool { return ccm == nil diff --git a/factory/mock/epochStartNotifierStub.go b/factory/mock/epochStartNotifierStub.go index 128242e1203..7e29fbae327 100644 --- a/factory/mock/epochStartNotifierStub.go +++ b/factory/mock/epochStartNotifierStub.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/factory/mock/forkDetectorMock.go b/factory/mock/forkDetectorMock.go index 4a041bc814a..217ee15e141 100644 --- a/factory/mock/forkDetectorMock.go +++ b/factory/mock/forkDetectorMock.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -19,6 +20,8 @@ type ForkDetectorMock struct { RestoreToGenesisCalled func() ResetProbableHighestNonceCalled func() SetFinalToLastCheckpointCalled func() + ReceivedProofCalled func(proof data.HeaderProofHandler) + AddCheckpointCalled func(nonce uint64, round uint64, hash []byte) } // RestoreToGenesis - @@ -111,6 +114,20 @@ func (fdm *ForkDetectorMock) SetFinalToLastCheckpoint() { } } +// ReceivedProof - +func (fdm *ForkDetectorMock) ReceivedProof(proof data.HeaderProofHandler) { + if fdm.ReceivedProofCalled != nil { + fdm.ReceivedProofCalled(proof) + } +} + +// AddCheckpoint - +func (fdm *ForkDetectorMock) AddCheckpoint(nonce uint64, round uint64, hash []byte) { + if fdm.AddCheckpointCalled != nil { + fdm.AddCheckpointCalled(nonce, round, hash) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (fdm *ForkDetectorMock) IsInterfaceNil() bool { return fdm == nil diff --git a/factory/mock/headerSigVerifierStub.go b/factory/mock/headerSigVerifierStub.go deleted file mode 100644 index 03a7e9b2658..00000000000 --- a/factory/mock/headerSigVerifierStub.go +++ /dev/null @@ -1,49 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderSigVerifierStub - -type HeaderSigVerifierStub struct { - VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error - VerifyRandSeedCalled func(header data.HeaderHandler) error - VerifyLeaderSignatureCalled func(header data.HeaderHandler) error - VerifySignatureCalled func(header data.HeaderHandler) error -} - -// VerifyRandSeed - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeed(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedCalled != nil { - return hsvm.VerifyRandSeedCalled(header) - } - - return nil -} - -// VerifyLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyLeaderSignatureCalled != nil { - return hsvm.VerifyLeaderSignatureCalled(header) - } - 
return nil -} - -// VerifyRandSeedAndLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedAndLeaderSignatureCalled != nil { - return hsvm.VerifyRandSeedAndLeaderSignatureCalled(header) - } - return nil -} - -// VerifySignature - -func (hsvm *HeaderSigVerifierStub) VerifySignature(header data.HeaderHandler) error { - if hsvm.VerifySignatureCalled != nil { - return hsvm.VerifySignatureCalled(header) - } - return nil -} - -// IsInterfaceNil - -func (hsvm *HeaderSigVerifierStub) IsInterfaceNil() bool { - return hsvm == nil -} diff --git a/factory/mock/p2pAntifloodHandlerStub.go b/factory/mock/p2pAntifloodHandlerStub.go index bda3da406d5..2a6a10fbde8 100644 --- a/factory/mock/p2pAntifloodHandlerStub.go +++ b/factory/mock/p2pAntifloodHandlerStub.go @@ -16,6 +16,7 @@ type P2PAntifloodHandlerStub struct { SetDebuggerCalled func(debugger process.AntifloodDebugger) error BlacklistPeerCalled func(peer core.PeerID, reason string, duration time.Duration) IsOriginatorEligibleForTopicCalled func(pid core.PeerID, topic string) error + SetConsensusSizeNotifierCalled func(subscriber process.ChainParametersSubscriber, shardID uint32) } // CanProcessMessage - @@ -42,10 +43,10 @@ func (p2pahs *P2PAntifloodHandlerStub) CanProcessMessagesOnTopic(peer core.PeerI return p2pahs.CanProcessMessagesOnTopicCalled(peer, topic, numMessages, totalSize, sequence) } -// ApplyConsensusSize - -func (p2pahs *P2PAntifloodHandlerStub) ApplyConsensusSize(size int) { - if p2pahs.ApplyConsensusSizeCalled != nil { - p2pahs.ApplyConsensusSizeCalled(size) +// SetConsensusSizeNotifier - +func (p2pahs *P2PAntifloodHandlerStub) SetConsensusSizeNotifier(subscriber process.ChainParametersSubscriber, shardID uint32) { + if p2pahs.SetConsensusSizeNotifierCalled != nil { + p2pahs.SetConsensusSizeNotifierCalled(subscriber, shardID) } } diff --git a/factory/mock/rounderMock.go b/factory/mock/rounderMock.go index 0cdd4ab1bde..4a56cb166aa 100644 --- a/factory/mock/rounderMock.go +++ b/factory/mock/rounderMock.go @@ -19,6 +19,14 @@ func (rndm *RoundHandlerMock) BeforeGenesis() bool { return false } +// RevertOneRound - +func (rndm *RoundHandlerMock) RevertOneRound() { + rndm.mutRoundHandler.Lock() + rndm.RoundIndex-- + rndm.RoundTimeStamp = rndm.RoundTimeStamp.Add(-rndm.RoundTimeDuration) + rndm.mutRoundHandler.Unlock() +} + // Index - func (rndm *RoundHandlerMock) Index() int64 { rndm.mutRoundHandler.RLock() diff --git a/factory/mock/stateComponentsHolderStub.go b/factory/mock/stateComponentsHolderStub.go index c851fdc6dac..e6b6b6b86cf 100644 --- a/factory/mock/stateComponentsHolderStub.go +++ b/factory/mock/stateComponentsHolderStub.go @@ -14,6 +14,7 @@ type StateComponentsHolderStub struct { TriesContainerCalled func() common.TriesHolder TrieStorageManagersCalled func() map[string]common.StorageManager MissingTrieNodesNotifierCalled func() common.MissingTrieNodesNotifier + TrieLeavesRetrieverCalled func() common.TrieLeavesRetriever } // PeerAccounts - @@ -79,6 +80,14 @@ func (s *StateComponentsHolderStub) MissingTrieNodesNotifier() common.MissingTri return nil } +// TrieLeavesRetriever - +func (s *StateComponentsHolderStub) TrieLeavesRetriever() common.TrieLeavesRetriever { + if s.TrieLeavesRetrieverCalled != nil { + return s.TrieLeavesRetrieverCalled() + } + return nil +} + // Close - func (s *StateComponentsHolderStub) Close() error { return nil diff --git a/factory/mock/testdata/nodesSetupMock.json b/factory/mock/testdata/nodesSetupMock.json deleted 
file mode 100644 index 905496ad7c3..00000000000 --- a/factory/mock/testdata/nodesSetupMock.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "startTime": 0, - "roundDuration": 4000, - "consensusGroupSize": 1, - "minNodesPerShard": 1, - "metaChainActive" : true, - "metaChainConsensusGroupSize" : 1, - "metaChainMinNodes" : 1, - "hysteresis": 0, - "adaptivity": false, - "initialNodes": [ - { - "pubkey": "227a5a5ec0c58171b7f4ee9ecc304ea7b176fb626741a25c967add76d6cd361d6995929f9b60a96237381091cefb1b061225e5bb930b40494a5ac9d7524fd67dfe478e5ccd80f17b093cff5722025761fb0217c39dbd5ae45e01eb5a3113be93", - "address": "erd1ulhw20j7jvgfgak5p05kv667k5k9f320sgef5ayxkt9784ql0zssrzyhjp" - }, - { - "pubkey": "ef9522d654bc08ebf2725468f41a693aa7f3cf1cb93922cff1c8c81fba78274016010916f4a7e5b0855c430a724a2d0b3acd1fe8e61e37273a17d58faa8c0d3ef6b883a33ec648950469a1e9757b978d9ae662a019068a401cff56eea059fd08", - "address": "erd17c4fs6mz2aa2hcvva2jfxdsrdknu4220496jmswer9njznt22eds0rxlr4" - }, - { - "pubkey": "e91ab494cedd4da346f47aaa1a3e792bea24fb9f6cc40d3546bc4ca36749b8bfb0164e40dbad2195a76ee0fd7fb7da075ecbf1b35a2ac20638d53ea5520644f8c16952225c48304bb202867e2d71d396bff5a5971f345bcfe32c7b6b0ca34c84", - "address": "erd10d2gufxesrp8g409tzxljlaefhs0rsgjle3l7nq38de59txxt8csj54cd3" - } - ] -} diff --git a/factory/mock/testdata/nodesSetupMockInvalidRound.json b/factory/mock/testdata/nodesSetupMockInvalidRound.json deleted file mode 100644 index df96538e573..00000000000 --- a/factory/mock/testdata/nodesSetupMockInvalidRound.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "startTime": 0, - "roundDuration": 500, - "consensusGroupSize": 1, - "minNodesPerShard": 1, - "metaChainActive": true, - "metaChainConsensusGroupSize": 1, - "metaChainMinNodes": 1, - "hysteresis": 0, - "adaptivity": false, - "initialNodes": [ - { - "pubkey": "227a5a5ec0c58171b7f4ee9ecc304ea7b176fb626741a25c967add76d6cd361d6995929f9b60a96237381091cefb1b061225e5bb930b40494a5ac9d7524fd67dfe478e5ccd80f17b093cff5722025761fb0217c39dbd5ae45e01eb5a3113be93", - "address": "erd1ulhw20j7jvgfgak5p05kv667k5k9f320sgef5ayxkt9784ql0zssrzyhjp" - }, - { - "pubkey": "ef9522d654bc08ebf2725468f41a693aa7f3cf1cb93922cff1c8c81fba78274016010916f4a7e5b0855c430a724a2d0b3acd1fe8e61e37273a17d58faa8c0d3ef6b883a33ec648950469a1e9757b978d9ae662a019068a401cff56eea059fd08", - "address": "erd17c4fs6mz2aa2hcvva2jfxdsrdknu4220496jmswer9njznt22eds0rxlr4" - }, - { - "pubkey": "e91ab494cedd4da346f47aaa1a3e792bea24fb9f6cc40d3546bc4ca36749b8bfb0164e40dbad2195a76ee0fd7fb7da075ecbf1b35a2ac20638d53ea5520644f8c16952225c48304bb202867e2d71d396bff5a5971f345bcfe32c7b6b0ca34c84", - "address": "erd10d2gufxesrp8g409tzxljlaefhs0rsgjle3l7nq38de59txxt8csj54cd3" - } - ] -} diff --git a/factory/peerSignatureHandler/peerSignatureHandler_test.go b/factory/peerSignatureHandler/peerSignatureHandler_test.go index 15395f65379..9f01857b73d 100644 --- a/factory/peerSignatureHandler/peerSignatureHandler_test.go +++ b/factory/peerSignatureHandler/peerSignatureHandler_test.go @@ -7,11 +7,12 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + errorsErd "github.com/multiversx/mx-chain-go/errors" "github.com/multiversx/mx-chain-go/factory/peerSignatureHandler" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" 
"github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" - "github.com/stretchr/testify/assert" ) func TestNewPeerSignatureHandler_NilCacherShouldErr(t *testing.T) { @@ -31,7 +32,7 @@ func TestNewPeerSignatureHandler_NilSingleSignerShouldErr(t *testing.T) { t.Parallel() peerSigHandler, err := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), nil, &cryptoMocks.KeyGenStub{}, ) @@ -44,7 +45,7 @@ func TestNewPeerSignatureHandler_NilKeyGenShouldErr(t *testing.T) { t.Parallel() peerSigHandler, err := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, nil, ) @@ -57,7 +58,7 @@ func TestNewPeerSignatureHandler_OkParamsShouldWork(t *testing.T) { t.Parallel() peerSigHandler, err := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, &cryptoMocks.KeyGenStub{}, ) @@ -70,7 +71,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureInvalidPk(t *testing.T) { t.Parallel() peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, &cryptoMocks.KeyGenStub{}, ) @@ -83,7 +84,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureInvalidPID(t *testing.T) { t.Parallel() peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, &cryptoMocks.KeyGenStub{}, ) @@ -96,7 +97,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureInvalidSignature(t *testing.T) t.Parallel() peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, &cryptoMocks.KeyGenStub{}, ) @@ -116,7 +117,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureCantGetPubKeyBytes(t *testing.T } peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, keyGen, ) @@ -133,7 +134,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureSigNotFoundInCache(t *testing.T pid := "dummy peer" sig := []byte("signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() keyGen := &cryptoMocks.KeyGenStub{ PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { return &cryptoMocks.PublicKeyStub{ @@ -179,7 +180,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureWrongEntryInCache(t *testing.T) pid := "dummy peer" sig := []byte("signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() cache.Put(pk, wrongType, len(wrongType)) keyGen := &cryptoMocks.KeyGenStub{ @@ -228,7 +229,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureNewPidAndSig(t *testing.T) { newPid := core.PeerID("new dummy peer") newSig := []byte("new sig") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() keyGen := &cryptoMocks.KeyGenStub{ PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { return &cryptoMocks.PublicKeyStub{ @@ -277,7 +278,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureDifferentPid(t *testing.T) { sig := []byte("signature") newPid := core.PeerID("new dummy peer") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() keyGen := &cryptoMocks.KeyGenStub{ PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { return 
&cryptoMocks.PublicKeyStub{ @@ -317,7 +318,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureDifferentSig(t *testing.T) { sig := []byte("signature") newSig := []byte("new signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() keyGen := &cryptoMocks.KeyGenStub{ PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { return &cryptoMocks.PublicKeyStub{ @@ -356,7 +357,7 @@ func TestPeerSignatureHandler_VerifyPeerSignatureGetFromCache(t *testing.T) { pid := core.PeerID("dummy peer") sig := []byte("signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() keyGen := &cryptoMocks.KeyGenStub{ PublicKeyFromByteArrayStub: func(b []byte) (crypto.PublicKey, error) { return &cryptoMocks.PublicKeyStub{ @@ -399,7 +400,7 @@ func TestPeerSignatureHandler_GetPeerSignatureErrInConvertingPrivateKeyToByteArr pid := []byte("dummy peer") peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler( - testscommon.NewCacherMock(), + cache.NewCacherMock(), &cryptoMocks.SingleSignerStub{}, &cryptoMocks.KeyGenStub{}, ) @@ -422,7 +423,7 @@ func TestPeerSignatureHandler_GetPeerSignatureNotPresentInCache(t *testing.T) { pid := []byte("dummy peer") sig := []byte("signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() singleSigner := &cryptoMocks.SingleSignerStub{ SignCalled: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { signCalled = true @@ -465,7 +466,7 @@ func TestPeerSignatureHandler_GetPeerSignatureWrongEntryInCache(t *testing.T) { sig := []byte("signature") wrongEntry := []byte("wrong entry") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() singleSigner := &cryptoMocks.SingleSignerStub{ SignCalled: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { signCalled = true @@ -511,7 +512,7 @@ func TestPeerSignatureHandler_GetPeerSignatureDifferentPidInCache(t *testing.T) sig := []byte("signature") newSig := []byte("new signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() singleSigner := &cryptoMocks.SingleSignerStub{ SignCalled: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { signCalled = true @@ -555,7 +556,7 @@ func TestPeerSignatureHandler_GetPeerSignatureGetFromCache(t *testing.T) { pid := []byte("dummy peer") sig := []byte("signature") - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() singleSigner := &cryptoMocks.SingleSignerStub{ SignCalled: func(private crypto.PrivateKey, msg []byte) ([]byte, error) { return nil, nil diff --git a/factory/processing/blockProcessorCreator.go b/factory/processing/blockProcessorCreator.go index b4ebb0915b8..bde320e0338 100644 --- a/factory/processing/blockProcessorCreator.go +++ b/factory/processing/blockProcessorCreator.go @@ -1038,6 +1038,7 @@ func (pcf *processComponentsFactory) createOutportDataProvider( MbsStorer: mbsStorer, EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), ExecutionOrderGetter: pcf.txExecutionOrderHandler, + ProofsPool: pcf.data.Datapool().Proofs(), }) } diff --git a/factory/processing/blockProcessorCreator_test.go b/factory/processing/blockProcessorCreator_test.go index e80b62491d2..f56cc9a0d24 100644 --- a/factory/processing/blockProcessorCreator_test.go +++ b/factory/processing/blockProcessorCreator_test.go @@ -8,6 +8,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + vmcommon 
"github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" dataComp "github.com/multiversx/mx-chain-go/factory/data" @@ -26,8 +29,6 @@ import ( storageManager "github.com/multiversx/mx-chain-go/testscommon/storage" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/require" ) func Test_newBlockProcessorCreatorForShard(t *testing.T) { diff --git a/factory/processing/processComponents.go b/factory/processing/processComponents.go index eb04fc9a4c5..3cff8f8f0b6 100644 --- a/factory/processing/processComponents.go +++ b/factory/processing/processComponents.go @@ -57,6 +57,7 @@ import ( "github.com/multiversx/mx-chain-go/process/factory/interceptorscontainer" "github.com/multiversx/mx-chain-go/process/headerCheck" "github.com/multiversx/mx-chain-go/process/heartbeat/validator" + interceptorFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" "github.com/multiversx/mx-chain-go/process/peer" "github.com/multiversx/mx-chain-go/process/receipts" "github.com/multiversx/mx-chain-go/process/smartContract" @@ -133,6 +134,7 @@ type processComponents struct { receiptsRepository mainFactory.ReceiptsRepository sentSignaturesTracker process.SentSignaturesTracker epochSystemSCProcessor process.EpochStartSystemSCProcessor + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory epochStartTriggerHanlder epochStart.TriggerHandler } @@ -209,6 +211,8 @@ type processComponentsFactory struct { genesisNonce uint64 genesisRound uint64 + + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // NewProcessComponentsFactory will return a new instance of processComponentsFactory @@ -218,37 +222,43 @@ func NewProcessComponentsFactory(args ProcessComponentsFactoryArgs) (*processCom return nil, err } + interceptedDataVerifierFactory := interceptorFactory.NewInterceptedDataVerifierFactory(interceptorFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Duration(args.Config.InterceptedDataVerifier.CacheSpanInSec) * time.Second, + CacheExpiry: time.Duration(args.Config.InterceptedDataVerifier.CacheExpiryInSec) * time.Second, + }) + return &processComponentsFactory{ - config: args.Config, - epochConfig: args.EpochConfig, - prefConfigs: args.PrefConfigs, - importDBConfig: args.ImportDBConfig, - economicsConfig: args.EconomicsConfig, - accountsParser: args.AccountsParser, - smartContractParser: args.SmartContractParser, - gasSchedule: args.GasSchedule, - nodesCoordinator: args.NodesCoordinator, - data: args.Data, - coreData: args.CoreData, - crypto: args.Crypto, - state: args.State, - network: args.Network, - bootstrapComponents: args.BootstrapComponents, - statusComponents: args.StatusComponents, - requestedItemsHandler: args.RequestedItemsHandler, - whiteListHandler: args.WhiteListHandler, - whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, - maxRating: args.MaxRating, - systemSCConfig: args.SystemSCConfig, - importStartHandler: args.ImportStartHandler, - historyRepo: args.HistoryRepo, - epochNotifier: 
args.CoreData.EpochNotifier(), - statusCoreComponents: args.StatusCoreComponents, - flagsConfig: args.FlagsConfig, - txExecutionOrderHandler: args.TxExecutionOrderHandler, - genesisNonce: args.GenesisNonce, - genesisRound: args.GenesisRound, - roundConfig: args.RoundConfig, + config: args.Config, + epochConfig: args.EpochConfig, + prefConfigs: args.PrefConfigs, + importDBConfig: args.ImportDBConfig, + economicsConfig: args.EconomicsConfig, + accountsParser: args.AccountsParser, + smartContractParser: args.SmartContractParser, + gasSchedule: args.GasSchedule, + nodesCoordinator: args.NodesCoordinator, + data: args.Data, + coreData: args.CoreData, + crypto: args.Crypto, + state: args.State, + network: args.Network, + bootstrapComponents: args.BootstrapComponents, + statusComponents: args.StatusComponents, + requestedItemsHandler: args.RequestedItemsHandler, + whiteListHandler: args.WhiteListHandler, + whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, + maxRating: args.MaxRating, + systemSCConfig: args.SystemSCConfig, + importStartHandler: args.ImportStartHandler, + historyRepo: args.HistoryRepo, + epochNotifier: args.CoreData.EpochNotifier(), + statusCoreComponents: args.StatusCoreComponents, + flagsConfig: args.FlagsConfig, + txExecutionOrderHandler: args.TxExecutionOrderHandler, + genesisNonce: args.GenesisNonce, + genesisRound: args.GenesisRound, + roundConfig: args.RoundConfig, + interceptedDataVerifierFactory: interceptedDataVerifierFactory, }, nil } @@ -285,6 +295,10 @@ func (pcf *processComponentsFactory) Create() (*processComponents, error) { SingleSigVerifier: pcf.crypto.BlockSigner(), KeyGen: pcf.crypto.BlockSignKeyGen(), FallbackHeaderValidator: fallbackHeaderValidator, + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), + HeadersPool: pcf.data.Datapool().Headers(), + ProofsPool: pcf.data.Datapool().Proofs(), + StorageService: pcf.data.StorageService(), } headerSigVerifier, err := headerCheck.NewHeaderSigVerifier(argsHeaderSig) if err != nil { @@ -454,8 +468,9 @@ func (pcf *processComponentsFactory) Create() (*processComponents, error) { } argsHeaderValidator := block.ArgsHeaderValidator{ - Hasher: pcf.coreData.Hasher(), - Marshalizer: pcf.coreData.InternalMarshalizer(), + Hasher: pcf.coreData.Hasher(), + Marshalizer: pcf.coreData.InternalMarshalizer(), + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), } headerValidator, err := block.NewHeaderValidator(argsHeaderValidator) if err != nil { @@ -763,6 +778,7 @@ func (pcf *processComponentsFactory) Create() (*processComponents, error) { accountsParser: pcf.accountsParser, receiptsRepository: receiptsRepository, sentSignaturesTracker: sentSignaturesTracker, + interceptedDataVerifierFactory: pcf.interceptedDataVerifierFactory, epochStartTriggerHanlder: epochStartTrigger, }, nil } @@ -811,8 +827,9 @@ func (pcf *processComponentsFactory) newEpochStartTrigger(requestHandler epochSt shardCoordinator := pcf.bootstrapComponents.ShardCoordinator() if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { argsHeaderValidator := block.ArgsHeaderValidator{ - Hasher: pcf.coreData.Hasher(), - Marshalizer: pcf.coreData.InternalMarshalizer(), + Hasher: pcf.coreData.Hasher(), + Marshalizer: pcf.coreData.InternalMarshalizer(), + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), } headerValidator, err := block.NewHeaderValidator(argsHeaderValidator) if err != nil { @@ -1066,7 +1083,7 @@ func (pcf *processComponentsFactory) saveShardBlock(genesisBlockHash []byte, mar log.Error("error storing genesis shardblock", 
"error", errNotCritical.Error()) } - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(shardID) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(shardID) errNotCritical = pcf.data.StorageService().Put(hdrNonceHashDataUnit, nonceToByteSlice, genesisBlockHash) if errNotCritical != nil { log.Error("error storing genesis shard header (nonce-hash)", "error", errNotCritical.Error()) @@ -1326,17 +1343,21 @@ func (pcf *processComponentsFactory) newBlockTracker( ) (process.BlockTracker, error) { shardCoordinator := pcf.bootstrapComponents.ShardCoordinator() argBaseTracker := track.ArgBaseTracker{ - Hasher: pcf.coreData.Hasher(), - HeaderValidator: headerValidator, - Marshalizer: pcf.coreData.InternalMarshalizer(), - RequestHandler: requestHandler, - RoundHandler: pcf.coreData.RoundHandler(), - ShardCoordinator: shardCoordinator, - Store: pcf.data.StorageService(), - StartHeaders: genesisBlocks, - PoolsHolder: pcf.data.Datapool(), - WhitelistHandler: pcf.whiteListHandler, - FeeHandler: pcf.coreData.EconomicsData(), + Hasher: pcf.coreData.Hasher(), + HeaderValidator: headerValidator, + Marshalizer: pcf.coreData.InternalMarshalizer(), + RequestHandler: requestHandler, + RoundHandler: pcf.coreData.RoundHandler(), + ShardCoordinator: shardCoordinator, + Store: pcf.data.StorageService(), + StartHeaders: genesisBlocks, + PoolsHolder: pcf.data.Datapool(), + WhitelistHandler: pcf.whiteListHandler, + FeeHandler: pcf.coreData.EconomicsData(), + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), + ProofsPool: pcf.data.Datapool().Proofs(), + IsImportDBMode: pcf.importDBConfig.IsImportDBMode, + EpochChangeGracePeriodHandler: pcf.coreData.EpochChangeGracePeriodHandler(), } if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { @@ -1474,6 +1495,7 @@ func (pcf *processComponentsFactory) newRequestersContainerFactory( FullArchivePreferredPeersHolder: pcf.network.FullArchivePreferredPeersHolderHandler(), PeersRatingHandler: pcf.network.PeersRatingHandler(), SizeCheckDelta: pcf.config.Marshalizer.SizeCheckDelta, + EnableEpochsHandler: pcf.coreData.EnableEpochsHandler(), } if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { @@ -1668,36 +1690,37 @@ func (pcf *processComponentsFactory) newShardInterceptorContainerFactory( ) (process.InterceptorsContainerFactory, process.TimeCacher, error) { headerBlackList := cache.NewTimeCache(timeSpanForBadHeaders) shardInterceptorsContainerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: pcf.coreData, - CryptoComponents: pcf.crypto, - Accounts: pcf.state.AccountsAdapter(), - ShardCoordinator: pcf.bootstrapComponents.ShardCoordinator(), - NodesCoordinator: pcf.nodesCoordinator, - MainMessenger: pcf.network.NetworkMessenger(), - FullArchiveMessenger: pcf.network.FullArchiveNetworkMessenger(), - Store: pcf.data.StorageService(), - DataPool: pcf.data.Datapool(), - MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, - TxFeeHandler: pcf.coreData.EconomicsData(), - BlockBlackList: headerBlackList, - HeaderSigVerifier: headerSigVerifier, - HeaderIntegrityVerifier: headerIntegrityVerifier, - ValidityAttester: validityAttester, - EpochStartTrigger: epochStartTrigger, - WhiteListHandler: pcf.whiteListHandler, - WhiteListerVerifiedTxs: pcf.whiteListerVerifiedTxs, - AntifloodHandler: pcf.network.InputAntiFloodHandler(), - ArgumentsParser: smartContract.NewArgumentParser(), - PreferredPeersHolder: pcf.network.PreferredPeersHolderHandler(), - SizeCheckDelta: 
pcf.config.Marshalizer.SizeCheckDelta, - RequestHandler: requestHandler, - PeerSignatureHandler: pcf.crypto.PeerSignatureHandler(), - SignaturesHandler: pcf.network.NetworkMessenger(), - HeartbeatExpiryTimespanInSec: pcf.config.HeartbeatV2.HeartbeatExpiryTimespanInSec, - MainPeerShardMapper: mainPeerShardMapper, - FullArchivePeerShardMapper: fullArchivePeerShardMapper, - HardforkTrigger: hardforkTrigger, - NodeOperationMode: nodeOperationMode, + CoreComponents: pcf.coreData, + CryptoComponents: pcf.crypto, + Accounts: pcf.state.AccountsAdapter(), + ShardCoordinator: pcf.bootstrapComponents.ShardCoordinator(), + NodesCoordinator: pcf.nodesCoordinator, + MainMessenger: pcf.network.NetworkMessenger(), + FullArchiveMessenger: pcf.network.FullArchiveNetworkMessenger(), + Store: pcf.data.StorageService(), + DataPool: pcf.data.Datapool(), + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: pcf.coreData.EconomicsData(), + BlockBlackList: headerBlackList, + HeaderSigVerifier: headerSigVerifier, + HeaderIntegrityVerifier: headerIntegrityVerifier, + ValidityAttester: validityAttester, + EpochStartTrigger: epochStartTrigger, + WhiteListHandler: pcf.whiteListHandler, + WhiteListerVerifiedTxs: pcf.whiteListerVerifiedTxs, + AntifloodHandler: pcf.network.InputAntiFloodHandler(), + ArgumentsParser: smartContract.NewArgumentParser(), + PreferredPeersHolder: pcf.network.PreferredPeersHolderHandler(), + SizeCheckDelta: pcf.config.Marshalizer.SizeCheckDelta, + RequestHandler: requestHandler, + PeerSignatureHandler: pcf.crypto.PeerSignatureHandler(), + SignaturesHandler: pcf.network.NetworkMessenger(), + HeartbeatExpiryTimespanInSec: pcf.config.HeartbeatV2.HeartbeatExpiryTimespanInSec, + MainPeerShardMapper: mainPeerShardMapper, + FullArchivePeerShardMapper: fullArchivePeerShardMapper, + HardforkTrigger: hardforkTrigger, + NodeOperationMode: nodeOperationMode, + InterceptedDataVerifierFactory: pcf.interceptedDataVerifierFactory, } interceptorContainerFactory, err := interceptorscontainer.NewShardInterceptorsContainerFactory(shardInterceptorsContainerFactoryArgs) @@ -1721,36 +1744,37 @@ func (pcf *processComponentsFactory) newMetaInterceptorContainerFactory( ) (process.InterceptorsContainerFactory, process.TimeCacher, error) { headerBlackList := cache.NewTimeCache(timeSpanForBadHeaders) metaInterceptorsContainerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: pcf.coreData, - CryptoComponents: pcf.crypto, - ShardCoordinator: pcf.bootstrapComponents.ShardCoordinator(), - NodesCoordinator: pcf.nodesCoordinator, - MainMessenger: pcf.network.NetworkMessenger(), - FullArchiveMessenger: pcf.network.FullArchiveNetworkMessenger(), - Store: pcf.data.StorageService(), - DataPool: pcf.data.Datapool(), - Accounts: pcf.state.AccountsAdapter(), - MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, - TxFeeHandler: pcf.coreData.EconomicsData(), - BlockBlackList: headerBlackList, - HeaderSigVerifier: headerSigVerifier, - HeaderIntegrityVerifier: headerIntegrityVerifier, - ValidityAttester: validityAttester, - EpochStartTrigger: epochStartTrigger, - WhiteListHandler: pcf.whiteListHandler, - WhiteListerVerifiedTxs: pcf.whiteListerVerifiedTxs, - AntifloodHandler: pcf.network.InputAntiFloodHandler(), - ArgumentsParser: smartContract.NewArgumentParser(), - SizeCheckDelta: pcf.config.Marshalizer.SizeCheckDelta, - PreferredPeersHolder: pcf.network.PreferredPeersHolderHandler(), - RequestHandler: requestHandler, - PeerSignatureHandler: 
pcf.crypto.PeerSignatureHandler(), - SignaturesHandler: pcf.network.NetworkMessenger(), - HeartbeatExpiryTimespanInSec: pcf.config.HeartbeatV2.HeartbeatExpiryTimespanInSec, - MainPeerShardMapper: mainPeerShardMapper, - FullArchivePeerShardMapper: fullArchivePeerShardMapper, - HardforkTrigger: hardforkTrigger, - NodeOperationMode: nodeOperationMode, + CoreComponents: pcf.coreData, + CryptoComponents: pcf.crypto, + ShardCoordinator: pcf.bootstrapComponents.ShardCoordinator(), + NodesCoordinator: pcf.nodesCoordinator, + MainMessenger: pcf.network.NetworkMessenger(), + FullArchiveMessenger: pcf.network.FullArchiveNetworkMessenger(), + Store: pcf.data.StorageService(), + DataPool: pcf.data.Datapool(), + Accounts: pcf.state.AccountsAdapter(), + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: pcf.coreData.EconomicsData(), + BlockBlackList: headerBlackList, + HeaderSigVerifier: headerSigVerifier, + HeaderIntegrityVerifier: headerIntegrityVerifier, + ValidityAttester: validityAttester, + EpochStartTrigger: epochStartTrigger, + WhiteListHandler: pcf.whiteListHandler, + WhiteListerVerifiedTxs: pcf.whiteListerVerifiedTxs, + AntifloodHandler: pcf.network.InputAntiFloodHandler(), + ArgumentsParser: smartContract.NewArgumentParser(), + SizeCheckDelta: pcf.config.Marshalizer.SizeCheckDelta, + PreferredPeersHolder: pcf.network.PreferredPeersHolderHandler(), + RequestHandler: requestHandler, + PeerSignatureHandler: pcf.crypto.PeerSignatureHandler(), + SignaturesHandler: pcf.network.NetworkMessenger(), + HeartbeatExpiryTimespanInSec: pcf.config.HeartbeatV2.HeartbeatExpiryTimespanInSec, + MainPeerShardMapper: mainPeerShardMapper, + FullArchivePeerShardMapper: fullArchivePeerShardMapper, + HardforkTrigger: hardforkTrigger, + NodeOperationMode: nodeOperationMode, + InterceptedDataVerifierFactory: pcf.interceptedDataVerifierFactory, } interceptorContainerFactory, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(metaInterceptorsContainerFactoryArgs) @@ -1767,10 +1791,22 @@ func (pcf *processComponentsFactory) newForkDetector( ) (process.ForkDetector, error) { shardCoordinator := pcf.bootstrapComponents.ShardCoordinator() if shardCoordinator.SelfId() < shardCoordinator.NumberOfShards() { - return sync.NewShardForkDetector(pcf.coreData.RoundHandler(), headerBlackList, blockTracker, pcf.coreData.GenesisNodesSetup().GetStartTime()) + return sync.NewShardForkDetector( + pcf.coreData.RoundHandler(), + headerBlackList, + blockTracker, + pcf.coreData.GenesisNodesSetup().GetStartTime(), + pcf.coreData.EnableEpochsHandler(), + pcf.data.Datapool().Proofs()) } if shardCoordinator.SelfId() == core.MetachainShardId { - return sync.NewMetaForkDetector(pcf.coreData.RoundHandler(), headerBlackList, blockTracker, pcf.coreData.GenesisNodesSetup().GetStartTime()) + return sync.NewMetaForkDetector( + pcf.coreData.RoundHandler(), + headerBlackList, + blockTracker, + pcf.coreData.GenesisNodesSetup().GetStartTime(), + pcf.coreData.EnableEpochsHandler(), + pcf.data.Datapool().Proofs()) } return nil, errors.New("could not create fork detector") @@ -1850,6 +1886,7 @@ func (pcf *processComponentsFactory) createExportFactoryHandler( NumConcurrentTrieSyncers: pcf.config.TrieSync.NumConcurrentTrieSyncers, TrieSyncerVersion: pcf.config.TrieSync.TrieSyncerVersion, NodeOperationMode: nodeOperationMode, + InterceptedDataVerifierFactory: pcf.interceptedDataVerifierFactory, } return updateFactory.NewExportHandlerFactory(argsExporter) } @@ -2047,6 +2084,9 @@ func (pc *processComponents) Close() error { if 
!check.IfNil(pc.txsSender) { log.LogIfError(pc.txsSender.Close()) } + if !check.IfNil(pc.interceptedDataVerifierFactory) { + log.LogIfError(pc.interceptedDataVerifierFactory.Close()) + } return nil } diff --git a/factory/processing/processComponents_test.go b/factory/processing/processComponents_test.go index e714c278381..157088fd9f7 100644 --- a/factory/processing/processComponents_test.go +++ b/factory/processing/processComponents_test.go @@ -17,8 +17,11 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing/blake2b" "github.com/multiversx/mx-chain-core-go/hashing/keccak" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/factory" + "github.com/multiversx/mx-chain-go/common/graceperiod" disabledStatistics "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" errorsMx "github.com/multiversx/mx-chain-go/errors" @@ -55,7 +58,6 @@ import ( testState "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" updateMocks "github.com/multiversx/mx-chain-go/update/mock" - "github.com/stretchr/testify/require" ) const ( @@ -78,7 +80,7 @@ var ( ) func createMockProcessComponentsFactoryArgs() processComp.ProcessComponentsFactoryArgs { - + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) args := processComp.ProcessComponentsFactoryArgs{ Config: testscommon.GetGeneralConfig(), EpochConfig: config.EpochConfig{ @@ -208,22 +210,23 @@ func createMockProcessComponentsFactoryArgs() processComp.ProcessComponentsFacto return big.NewInt(100000000) }, }, - Hash: blake2b.NewBlake2b(), - TxVersionCheckHandler: &testscommon.TxVersionCheckerStub{}, - RatingHandler: &testscommon.RaterMock{}, - EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, - EpochNotifierWithConfirm: &updateMocks.EpochStartNotifierStub{}, - RoundHandlerField: &testscommon.RoundHandlerMock{}, - RoundChangeNotifier: &epochNotifier.RoundNotifierStub{}, - ChanStopProcess: make(chan endProcess.ArgEndProcess, 1), - TxSignHasherField: keccak.NewKeccak(), - HardforkTriggerPubKeyField: []byte("hardfork pub key"), - WasmVMChangeLockerInternal: &sync.RWMutex{}, - NodeTypeProviderField: &nodeTypeProviderMock.NodeTypeProviderStub{}, - RatingsConfig: &testscommon.RatingsInfoMock{}, - PathHdl: &testscommon.PathManagerStub{}, - ProcessStatusHandlerInternal: &testscommon.ProcessStatusHandlerStub{}, + EpochChangeGracePeriodHandlerField: gracePeriod, + Hash: blake2b.NewBlake2b(), + TxVersionCheckHandler: &testscommon.TxVersionCheckerStub{}, + RatingHandler: &testscommon.RaterMock{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, + EpochNotifierWithConfirm: &updateMocks.EpochStartNotifierStub{}, + RoundHandlerField: &testscommon.RoundHandlerMock{}, + RoundChangeNotifier: &epochNotifier.RoundNotifierStub{}, + ChanStopProcess: make(chan endProcess.ArgEndProcess, 1), + TxSignHasherField: keccak.NewKeccak(), + 
HardforkTriggerPubKeyField: []byte("hardfork pub key"), + WasmVMChangeLockerInternal: &sync.RWMutex{}, + NodeTypeProviderField: &nodeTypeProviderMock.NodeTypeProviderStub{}, + RatingsConfig: &testscommon.RatingsInfoMock{}, + PathHdl: &testscommon.PathManagerStub{}, + ProcessStatusHandlerInternal: &testscommon.ProcessStatusHandlerStub{}, }, Crypto: &testsMocks.CryptoComponentsStub{ BlKeyGen: &cryptoMocks.KeyGenStub{}, diff --git a/factory/state/stateComponents.go b/factory/state/stateComponents.go index 8da3251e230..e09aae7b1c9 100644 --- a/factory/state/stateComponents.go +++ b/factory/state/stateComponents.go @@ -19,6 +19,7 @@ import ( "github.com/multiversx/mx-chain-go/state/storagePruningManager/evictionWaitingList" "github.com/multiversx/mx-chain-go/state/syncer" trieFactory "github.com/multiversx/mx-chain-go/trie/factory" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever" ) // TODO: merge this with data components @@ -53,6 +54,7 @@ type stateComponents struct { triesContainer common.TriesHolder trieStorageManagers map[string]common.StorageManager missingTrieNodesNotifier common.MissingTrieNodesNotifier + trieLeavesRetriever common.TrieLeavesRetriever } // NewStateComponentsFactory will return a new instance of stateComponentsFactory @@ -100,6 +102,11 @@ func (scf *stateComponentsFactory) Create() (*stateComponents, error) { return nil, err } + trieLeavesRetriever, err := scf.createTrieLeavesRetriever(trieStorageManagers[dataRetriever.UserAccountsUnit.String()]) + if err != nil { + return nil, err + } + return &stateComponents{ peerAccounts: peerAdapter, accountsAdapter: accountsAdapter, @@ -108,9 +115,23 @@ func (scf *stateComponentsFactory) Create() (*stateComponents, error) { triesContainer: triesContainer, trieStorageManagers: trieStorageManagers, missingTrieNodesNotifier: syncer.NewMissingTrieNodesNotifier(), + trieLeavesRetriever: trieLeavesRetriever, }, nil } +func (scf *stateComponentsFactory) createTrieLeavesRetriever(trieStorage common.TrieStorageInteractor) (common.TrieLeavesRetriever, error) { + if !scf.config.TrieLeavesRetrieverConfig.Enabled { + return leavesRetriever.NewDisabledLeavesRetriever(), nil + } + + return leavesRetriever.NewLeavesRetriever( + trieStorage, + scf.core.InternalMarshalizer(), + scf.core.Hasher(), + scf.config.TrieLeavesRetrieverConfig.MaxSizeInBytes, + ) +} + func (scf *stateComponentsFactory) createSnapshotManager( accountFactory state.AccountFactory, stateMetrics state.StateMetrics, diff --git a/factory/state/stateComponentsHandler.go b/factory/state/stateComponentsHandler.go index 78271a28ffe..e84c1f8b3b5 100644 --- a/factory/state/stateComponentsHandler.go +++ b/factory/state/stateComponentsHandler.go @@ -93,6 +93,9 @@ func (msc *managedStateComponents) CheckSubcomponents() error { if check.IfNil(msc.missingTrieNodesNotifier) { return errors.ErrNilMissingTrieNodesNotifier } + if check.IfNil(msc.trieLeavesRetriever) { + return errors.ErrNilTrieLeavesRetriever + } return nil } @@ -214,6 +217,18 @@ func (msc *managedStateComponents) MissingTrieNodesNotifier() common.MissingTrie return msc.stateComponents.missingTrieNodesNotifier } +// TrieLeavesRetriever returns the trie leaves retriever +func (msc *managedStateComponents) TrieLeavesRetriever() common.TrieLeavesRetriever { + msc.mutStateComponents.RLock() + defer msc.mutStateComponents.RUnlock() + + if msc.stateComponents == nil { + return nil + } + + return msc.stateComponents.trieLeavesRetriever +} + // IsInterfaceNil 
returns true if the interface is nil func (msc *managedStateComponents) IsInterfaceNil() bool { return msc == nil diff --git a/factory/statusCore/statusCoreComponents.go b/factory/statusCore/statusCoreComponents.go index d32ee129a9d..b25cbca427e 100644 --- a/factory/statusCore/statusCoreComponents.go +++ b/factory/statusCore/statusCoreComponents.go @@ -163,6 +163,12 @@ func (sccf *statusCoreComponentsFactory) createStatusHandler() (core.AppStatusHa return nil, nil, nil, err } + err = sccf.coreComp.RatingsData().SetStatusHandler(handler) + if err != nil { + log.Debug("cannot set status handler to ratingsData", "error", err) + return nil, nil, nil, err + } + return handler, statusMetrics, persistentHandler, nil } diff --git a/factory/statusCore/statusCoreComponents_test.go b/factory/statusCore/statusCoreComponents_test.go index 0248500bbe7..3d3ff69c005 100644 --- a/factory/statusCore/statusCoreComponents_test.go +++ b/factory/statusCore/statusCoreComponents_test.go @@ -12,6 +12,7 @@ import ( "github.com/multiversx/mx-chain-go/factory/statusCore" "github.com/multiversx/mx-chain-go/integrationTests/mock" "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/testscommon" componentsMock "github.com/multiversx/mx-chain-go/testscommon/components" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/factory" @@ -86,7 +87,7 @@ func TestStatusCoreComponentsFactory_Create(t *testing.T) { require.Error(t, err) require.Nil(t, cc) }) - t.Run("SetStatusHandler fails should error", func(t *testing.T) { + t.Run("SetStatusHandler on economics data fails should error", func(t *testing.T) { t.Parallel() expectedErr := errors.New("expected error") @@ -106,6 +107,26 @@ func TestStatusCoreComponentsFactory_Create(t *testing.T) { require.Equal(t, expectedErr, err) require.Nil(t, cc) }) + t.Run("SetStatusHandler on ratings data fails should error", func(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("expected error") + coreCompStub := factory.NewCoreComponentsHolderStubFromRealComponent(componentsMock.GetCoreComponents()) + coreCompStub.RatingsDataCalled = func() process.RatingsInfoHandler { + return &testscommon.RatingsInfoMock{ + SetStatusHandlerCalled: func(statusHandler core.AppStatusHandler) error { + return expectedErr + }, + } + } + args := componentsMock.GetStatusCoreArgs(coreCompStub) + sccf, err := statusCore.NewStatusCoreComponentsFactory(args) + require.Nil(t, err) + + cc, err := sccf.Create() + require.Equal(t, expectedErr, err) + require.Nil(t, cc) + }) t.Run("should work", func(t *testing.T) { t.Parallel() diff --git a/fallback/headerValidator.go b/fallback/headerValidator.go index 8e2d0eda037..4b9110582b0 100644 --- a/fallback/headerValidator.go +++ b/fallback/headerValidator.go @@ -5,10 +5,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("fallback") @@ -45,28 +46,34 @@ func 
NewFallbackHeaderValidator( return hv, nil } -// ShouldApplyFallbackValidation returns if for the given header could be applied fallback validation or not -func (fhv *fallbackHeaderValidator) ShouldApplyFallbackValidation(headerHandler data.HeaderHandler) bool { - if check.IfNil(headerHandler) { - return false - } - if headerHandler.GetShardID() != core.MetachainShardId { +// ShouldApplyFallbackValidationForHeaderWith returns if for the given header data fallback validation could be applied or not +func (fhv *fallbackHeaderValidator) ShouldApplyFallbackValidationForHeaderWith(shardID uint32, startOfEpochBlock bool, round uint64, prevHeaderHash []byte) bool { + if shardID != core.MetachainShardId { return false } - if !headerHandler.IsStartOfEpochBlock() { + if !startOfEpochBlock { return false } - previousHeader, err := process.GetMetaHeader(headerHandler.GetPrevHash(), fhv.headersPool, fhv.marshalizer, fhv.storageService) + previousHeader, err := process.GetMetaHeader(prevHeaderHash, fhv.headersPool, fhv.marshalizer, fhv.storageService) if err != nil { log.Debug("ShouldApplyFallbackValidation", "GetMetaHeader", err.Error()) return false } - isRoundTooOld := int64(headerHandler.GetRound())-int64(previousHeader.GetRound()) >= common.MaxRoundsWithoutCommittedStartInEpochBlock + isRoundTooOld := int64(round)-int64(previousHeader.GetRound()) >= common.MaxRoundsWithoutCommittedStartInEpochBlock return isRoundTooOld } +// ShouldApplyFallbackValidation returns if for the given header could be applied fallback validation or not +func (fhv *fallbackHeaderValidator) ShouldApplyFallbackValidation(headerHandler data.HeaderHandler) bool { + if check.IfNil(headerHandler) { + return false + } + + return fhv.ShouldApplyFallbackValidationForHeaderWith(headerHandler.GetShardID(), headerHandler.IsStartOfEpochBlock(), headerHandler.GetRound(), headerHandler.GetPrevHash()) +} + // IsInterfaceNil returns true if there is no value under the interface func (fhv *fallbackHeaderValidator) IsInterfaceNil() bool { return fhv == nil diff --git a/genesis/process/disabled/requestHandler.go b/genesis/process/disabled/requestHandler.go index a1f26781b7d..d24590cc6d3 100644 --- a/genesis/process/disabled/requestHandler.go +++ b/genesis/process/disabled/requestHandler.go @@ -90,6 +90,14 @@ func (r *RequestHandler) RequestValidatorInfo(_ []byte) { func (r *RequestHandler) RequestValidatorsInfo(_ [][]byte) { } +// RequestEquivalentProofByHash does nothing +func (r *RequestHandler) RequestEquivalentProofByHash(_ uint32, _ []byte) { +} + +// RequestEquivalentProofByNonce does nothing +func (r *RequestHandler) RequestEquivalentProofByNonce(_ uint32, _ uint64) { +} + // IsInterfaceNil returns true if there is no value under the interface func (r *RequestHandler) IsInterfaceNil() bool { return r == nil diff --git a/go.mod b/go.mod index 382aa107994..78e8752d156 100644 --- a/go.mod +++ b/go.mod @@ -13,19 +13,21 @@ require ( github.com/google/gops v0.3.18 github.com/gorilla/websocket v1.5.3 github.com/klauspost/cpuid/v2 v2.2.9 + github.com/libp2p/go-libp2p v0.38.2 + github.com/libp2p/go-libp2p-pubsub v0.13.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/multiversx/mx-chain-communication-go v1.1.2-0.20250218164645-1f6964baffbe - github.com/multiversx/mx-chain-core-go v1.2.25-0.20250218161123-121084ae9840 - github.com/multiversx/mx-chain-crypto-go 
v1.2.13-0.20250218161752-9482d9a22234 - github.com/multiversx/mx-chain-es-indexer-go v1.7.17-0.20250218165903-7923d170f8f0 - github.com/multiversx/mx-chain-logger-go v1.0.16-0.20250218161408-6a0c19d0da48 - github.com/multiversx/mx-chain-scenario-go v1.5.1-0.20250218162624-877d8b9870a4 - github.com/multiversx/mx-chain-storage-go v1.0.20-0.20250218162234-85e60acebb43 - github.com/multiversx/mx-chain-vm-common-go v1.5.17-0.20250218162215-88938774627c - github.com/multiversx/mx-chain-vm-go v1.5.38-0.20250318180139-42d650e84043 - github.com/multiversx/mx-chain-vm-v1_2-go v1.2.69-0.20250220133402-01591d72f671 - github.com/multiversx/mx-chain-vm-v1_3-go v1.3.70-0.20250220133720-4abbb3b36387 - github.com/multiversx/mx-chain-vm-v1_4-go v1.4.99-0.20250220144348-9455d2a4e6e6 + github.com/multiversx/mx-chain-communication-go v1.2.1-0.20250520083403-3f2bad6d5476 + github.com/multiversx/mx-chain-core-go v1.3.2-0.20250520074139-18b645ad397a + github.com/multiversx/mx-chain-crypto-go v1.2.13-0.20250520075055-8ab2a164945d + github.com/multiversx/mx-chain-es-indexer-go v1.8.2-0.20250520083544-09915e4d9bae + github.com/multiversx/mx-chain-logger-go v1.0.16-0.20250520074859-b2faf3c90273 + github.com/multiversx/mx-chain-scenario-go v1.5.1-0.20250520075713-734e46b4c66d + github.com/multiversx/mx-chain-storage-go v1.0.20-0.20250520075958-65fd4c7bcaae + github.com/multiversx/mx-chain-vm-common-go v1.5.17-0.20250520075408-c94bee9ee163 + github.com/multiversx/mx-chain-vm-go v1.5.41-0.20250520080530-2838146363b4 + github.com/multiversx/mx-chain-vm-v1_2-go v1.2.69-0.20250520080927-410c413d962f + github.com/multiversx/mx-chain-vm-v1_3-go v1.3.70-0.20250520081414-edf8b75e054d + github.com/multiversx/mx-chain-vm-v1_4-go v1.4.99-0.20250520081749-516b5ae0e49c github.com/pelletier/go-toml v1.9.3 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.20.5 @@ -33,6 +35,7 @@ require ( github.com/stretchr/testify v1.10.0 github.com/urfave/cli v1.22.16 golang.org/x/crypto v0.32.0 + golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c gopkg.in/go-playground/validator.v8 v8.18.2 ) @@ -99,11 +102,9 @@ require ( github.com/libp2p/go-buffer-pool v0.1.0 // indirect github.com/libp2p/go-cidranger v1.1.0 // indirect github.com/libp2p/go-flow-metrics v0.2.0 // indirect - github.com/libp2p/go-libp2p v0.38.2 // indirect github.com/libp2p/go-libp2p-asn-util v0.4.1 // indirect github.com/libp2p/go-libp2p-kad-dht v0.29.0 // indirect github.com/libp2p/go-libp2p-kbucket v0.6.5 // indirect - github.com/libp2p/go-libp2p-pubsub v0.13.0 // indirect github.com/libp2p/go-libp2p-record v0.3.1 // indirect github.com/libp2p/go-libp2p-routing-helpers v0.7.4 // indirect github.com/libp2p/go-msgio v0.3.0 // indirect @@ -193,7 +194,6 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c // indirect 
golang.org/x/mod v0.22.0 // indirect golang.org/x/net v0.34.0 // indirect golang.org/x/sync v0.10.0 // indirect diff --git a/go.sum b/go.sum index 28c76ac2491..c2ed0c08cc3 100644 --- a/go.sum +++ b/go.sum @@ -87,7 +87,6 @@ github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c h1:pFUpOrbxDR github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c/go.mod h1:6UhI8N9EjYm1c2odKpFpAYeR8dsBeM7PtzQhRgxRr9U= github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc= github.com/decred/dcrd/crypto/blake256 v1.0.1 h1:7PltbUIQB7u/FfZ39+DGa/ShuMyJ5ilcvdfma9wOH6Y= -github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.1/go.mod h1:hyedUtir6IdtD/7lIxGeCxkaw7y45JueMRL4DIyJDKs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= @@ -113,7 +112,6 @@ github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwU github.com/francoispqt/gojay v1.2.13 h1:d2m3sFjloqoIUQU3TsHBgj6qg/BVGlTBeHDUmyJnXKk= github.com/francoispqt/gojay v1.2.13/go.mod h1:ehT5mTG4ua4581f1++1WLG0vPdaA9HaiDsoyrBGkyDY= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= -github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= @@ -144,7 +142,6 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= @@ -198,7 +195,6 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-github v17.0.0+incompatible/go.mod 
h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -220,7 +216,6 @@ github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c h1:7lF+Vz0LqiRidnzC1Oq86fpX1q/iEv2KJdrCtttYjT4= -github.com/gopherjs/gopherjs v0.0.0-20190430165422-3e4dfb77656c/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= @@ -245,7 +240,6 @@ github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1: github.com/ipfs/boxo v0.27.4 h1:6nC8lY5GnR6whAbW88hFz6L13wZUj2vr5BRe3iTvYBI= github.com/ipfs/boxo v0.27.4/go.mod h1:qEIRrGNr0bitDedTCzyzBHxzNWqYmyuHgK8LG9Q83EM= github.com/ipfs/go-block-format v0.2.0 h1:ZqrkxBA2ICbDRbK8KJs/u0O3dlp6gmAuuXUJNiW1Ycs= -github.com/ipfs/go-block-format v0.2.0/go.mod h1:+jpL11nFx5A/SPpsoBn6Bzkra/zaArfSmsknbPMYgzM= github.com/ipfs/go-cid v0.5.0 h1:goEKKhaGm0ul11IHA7I6p1GmKz8kEYniqFopaB5Otwg= github.com/ipfs/go-cid v0.5.0/go.mod h1:0L7vmeNXpQpUS9vt+yEARkJ8rOg43DF3iPgn4GIN0mk= github.com/ipfs/go-datastore v0.6.0 h1:JKyz+Gvz1QEZw0LsX1IBn+JFCJQH4SJVFtM4uWU0Myk= @@ -253,14 +247,12 @@ github.com/ipfs/go-datastore v0.6.0/go.mod h1:rt5M3nNbSO/8q1t4LNkLyUwRs8HupMeN/8 github.com/ipfs/go-detect-race v0.0.1 h1:qX/xay2W3E4Q1U7d9lNs1sU9nvguX0a7319XbyQ6cOk= github.com/ipfs/go-detect-race v0.0.1/go.mod h1:8BNT7shDZPo99Q74BpGMK+4D8Mn4j46UU0LZ723meps= github.com/ipfs/go-ipfs-util v0.0.3 h1:2RFdGez6bu2ZlZdI+rWfIdbQb1KudQp3VGwPtdNCmE0= -github.com/ipfs/go-ipfs-util v0.0.3/go.mod h1:LHzG1a0Ig4G+iZ26UUOMjHd+lfM84LZCrn17xAKWBvs= github.com/ipfs/go-log v1.0.5 h1:2dOuUCB1Z7uoczMWgAyDck5JLb72zHzrMnGnCNNbvY8= github.com/ipfs/go-log v1.0.5/go.mod h1:j0b8ZoR+7+R99LD9jZ6+AJsrzkPbSXbZfGakb5JPtIo= github.com/ipfs/go-log/v2 v2.1.3/go.mod h1:/8d0SH3Su5Ooc31QlL1WysJhvyOTDCjcCZ9Axpmri6g= github.com/ipfs/go-log/v2 v2.5.1 h1:1XdUzF7048prq4aBjDQQ4SL5RxftpRGdXhNRwKSAlcY= github.com/ipfs/go-log/v2 v2.5.1/go.mod h1:prSpmC1Gpllc9UYWxDiZDreBYw7zp4Iqp1kOLU9U5UI= github.com/ipfs/go-test v0.0.4 h1:DKT66T6GBB6PsDFLoO56QZPrOmzJkqU1FZH5C9ySkew= -github.com/ipfs/go-test v0.0.4/go.mod h1:qhIM1EluEfElKKM6fnWxGn822/z9knUGM1+I/OAQNKI= github.com/ipld/go-ipld-prime v0.21.0 h1:n4JmcpOlPDIxBcY037SVfpd1G+Sj1nKZah0m6QH9C2E= github.com/ipld/go-ipld-prime v0.21.0/go.mod h1:3RLqy//ERg/y5oShXXdx5YIp50cFGOanyMctpPjsvxQ= github.com/jackpal/go-nat-pmp v1.0.2 
h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= @@ -296,14 +288,12 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= -github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= @@ -328,7 +318,6 @@ github.com/libp2p/go-libp2p-record v0.3.1/go.mod h1:T8itUkLcWQLCYMqtX7Th6r7SexyU github.com/libp2p/go-libp2p-routing-helpers v0.7.4 h1:6LqS1Bzn5CfDJ4tzvP9uwh42IB7TJLNFJA6dEeGBv84= github.com/libp2p/go-libp2p-routing-helpers v0.7.4/go.mod h1:we5WDj9tbolBXOuF1hGOkR+r7Uh1408tQbAKaT5n1LE= github.com/libp2p/go-libp2p-testing v0.12.0 h1:EPvBb4kKMWO29qP4mZGyhVzUyR25dvfUIK5WDu6iPUA= -github.com/libp2p/go-libp2p-testing v0.12.0/go.mod h1:KcGDRXyN7sQCllucn1cOOS+Dmm7ujhfEyXQL5lvkcPg= github.com/libp2p/go-msgio v0.3.0 h1:mf3Z8B1xcFN314sWX+2vOTShIE0Mmn2TXn3YCUQGNj0= github.com/libp2p/go-msgio v0.3.0/go.mod h1:nyRM819GmVaF9LX3l03RMh10QdOroF++NBbxAb0mmDM= github.com/libp2p/go-nat v0.2.0 h1:Tyz+bUFAYqGyJ/ppPPymMGbIgNRH+WqC5QrT5fKrrGk= @@ -399,30 +388,30 @@ github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/n github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU= github.com/multiversx/concurrent-map v0.1.4 h1:hdnbM8VE4b0KYJaGY5yJS2aNIW9TFFsUYwbO0993uPI= github.com/multiversx/concurrent-map v0.1.4/go.mod h1:8cWFRJDOrWHOTNSqgYCUvwT7c7eFQ4U2vKMOp4A/9+o= -github.com/multiversx/mx-chain-communication-go v1.1.2-0.20250218164645-1f6964baffbe h1:cgaCosslTU6qqVJ3r4+xMfNudGjEaPuc6rIVTIxuSqo= -github.com/multiversx/mx-chain-communication-go v1.1.2-0.20250218164645-1f6964baffbe/go.mod h1:Em49dwv2INN13+ledsUYFNxvkdNKxbOgTxXS8gmmHyw= -github.com/multiversx/mx-chain-core-go v1.2.25-0.20250218161123-121084ae9840 h1:rwIljKJpbNLWNBj/oMdcbCKU910JytOXJoBqDYnfres= -github.com/multiversx/mx-chain-core-go v1.2.25-0.20250218161123-121084ae9840/go.mod h1:IO+vspNan+gT0WOHnJ95uvWygiziHZvfXpff6KnxV7g= 
-github.com/multiversx/mx-chain-crypto-go v1.2.13-0.20250218161752-9482d9a22234 h1:NNI7kYxzsq+4mTPSUJo0cK1+iPxjUX+gRJDaBRwEQ7M= -github.com/multiversx/mx-chain-crypto-go v1.2.13-0.20250218161752-9482d9a22234/go.mod h1:QZAw2bZcOxGQRgYACTrmP8pfTa3NyxENIL+00G6nM5E= -github.com/multiversx/mx-chain-es-indexer-go v1.7.17-0.20250218165903-7923d170f8f0 h1:+wyJShRImKPCvu5vanlnijHluQOgkh70ZTLnp5yQW1s= -github.com/multiversx/mx-chain-es-indexer-go v1.7.17-0.20250218165903-7923d170f8f0/go.mod h1:3O1SPXBD/69tPsDyBLfkGQmKzAp1kfLXcDgyenTIvSQ= -github.com/multiversx/mx-chain-logger-go v1.0.16-0.20250218161408-6a0c19d0da48 h1:Of8RfTBNqJMvfWrDEpAkCAmNjYciM/Hul+yECQMBSHY= -github.com/multiversx/mx-chain-logger-go v1.0.16-0.20250218161408-6a0c19d0da48/go.mod h1:PZMaAr6nhEWgOV04JKBwFNrws0gvHzHW0WaeqnBlGlc= -github.com/multiversx/mx-chain-scenario-go v1.5.1-0.20250218162624-877d8b9870a4 h1:Q/iRXtZ6HhPQ6mV5/KWzg9WeamM90JV/WNQj8uP93ls= -github.com/multiversx/mx-chain-scenario-go v1.5.1-0.20250218162624-877d8b9870a4/go.mod h1:9WV9g7ZOf+7ytXri7KRGInNbJSExUpcZ1BUKbWkJKps= -github.com/multiversx/mx-chain-storage-go v1.0.20-0.20250218162234-85e60acebb43 h1:gmd10vRDOK3QJ7njD/iafV/uaNXl/6QEZf+s+CH9k4c= -github.com/multiversx/mx-chain-storage-go v1.0.20-0.20250218162234-85e60acebb43/go.mod h1:tTVMcXx0UWdMymMv3N8b1D1P1XSQwfyGK6xwMlRoONo= -github.com/multiversx/mx-chain-vm-common-go v1.5.17-0.20250218162215-88938774627c h1:4L3SY1so6MwfmfO7+MGOhGtDxhVW5PtW6JG48sZmHNE= -github.com/multiversx/mx-chain-vm-common-go v1.5.17-0.20250218162215-88938774627c/go.mod h1:NGcFCdOnbpEdk042ixTgD6xavRFQ7ap0z3kBhTXKlDQ= -github.com/multiversx/mx-chain-vm-go v1.5.38-0.20250318180139-42d650e84043 h1:uomqjb4XOchsJs6TtHj0PjLXJpzKiP23ZcL1ZJi4JOM= -github.com/multiversx/mx-chain-vm-go v1.5.38-0.20250318180139-42d650e84043/go.mod h1:ee6MdfII+4DRrfMfEEzrhLiq7r2HZ4oKr/vAGHn8En8= -github.com/multiversx/mx-chain-vm-v1_2-go v1.2.69-0.20250220133402-01591d72f671 h1:xTbDPTaJQ0evqELiXQ4a1pinEAvoE7Y6/cmj4MUjzDA= -github.com/multiversx/mx-chain-vm-v1_2-go v1.2.69-0.20250220133402-01591d72f671/go.mod h1:QbNaHsEseQvrAT81VtbwUTTWPMrbDCDoXRJsY0V+1KU= -github.com/multiversx/mx-chain-vm-v1_3-go v1.3.70-0.20250220133720-4abbb3b36387 h1:B0AMhrWhUIN7HHNHTpfDJHAopTUHFjx8YVMDkb3++WA= -github.com/multiversx/mx-chain-vm-v1_3-go v1.3.70-0.20250220133720-4abbb3b36387/go.mod h1:sVUtPUIiCRxOrCrW9/ygqLN3J1pahbV0PBVY2V7c9cU= -github.com/multiversx/mx-chain-vm-v1_4-go v1.4.99-0.20250220144348-9455d2a4e6e6 h1:o52auTBcVK8WlZ6HWlvpIV+9Uo2M/SSdAZ80tK8CEp8= -github.com/multiversx/mx-chain-vm-v1_4-go v1.4.99-0.20250220144348-9455d2a4e6e6/go.mod h1:BIngPEmFJ0Jt5tG7vkdQ2zrgeidEo+XB6zibImBNre0= +github.com/multiversx/mx-chain-communication-go v1.2.1-0.20250520083403-3f2bad6d5476 h1:Dn73bH1AdG+7+3/FFRfOiivOEvwPyzZUBWWxpk8QVxc= +github.com/multiversx/mx-chain-communication-go v1.2.1-0.20250520083403-3f2bad6d5476/go.mod h1:99+FW6f7X0Ri5tph+2l2GaDVrdej1do89exkfh7gilE= +github.com/multiversx/mx-chain-core-go v1.3.2-0.20250520074139-18b645ad397a h1:dhCobNEcBdvutX+0UYF/l86oLVO9iUiUeF3sLFa9qhE= +github.com/multiversx/mx-chain-core-go v1.3.2-0.20250520074139-18b645ad397a/go.mod 
h1:IO+vspNan+gT0WOHnJ95uvWygiziHZvfXpff6KnxV7g= +github.com/multiversx/mx-chain-crypto-go v1.2.13-0.20250520075055-8ab2a164945d h1:NI5uKpkwP5XZu9gtDiWxmbbb07T9hXegPist17WAzY4= +github.com/multiversx/mx-chain-crypto-go v1.2.13-0.20250520075055-8ab2a164945d/go.mod h1:yekQt4uB5LYXtimbhpdUbnFexjucWrQG/t+AX55bdM8= +github.com/multiversx/mx-chain-es-indexer-go v1.8.2-0.20250520083544-09915e4d9bae h1:3I8l+SE/unbOhc3QcLmlLX1aCYssDI+oLDavRUw69LY= +github.com/multiversx/mx-chain-es-indexer-go v1.8.2-0.20250520083544-09915e4d9bae/go.mod h1:rU+8opckju7oVlu+UmD3esSV0jvP8Lu7E1raTRXDyv8= +github.com/multiversx/mx-chain-logger-go v1.0.16-0.20250520074859-b2faf3c90273 h1:1I2CgGDAMINxrKI6yzSP/Y6Wow2YUmqegUXcltpGXQA= +github.com/multiversx/mx-chain-logger-go v1.0.16-0.20250520074859-b2faf3c90273/go.mod h1:M/uRv1kpmkzxS5HsgofdRcOHzzvagD7nTmFqiPKt89U= +github.com/multiversx/mx-chain-scenario-go v1.5.1-0.20250520075713-734e46b4c66d h1:BEJHmDMqoDzgNWx/jWn191WoSHLCIQmKszyJGEYWlyA= +github.com/multiversx/mx-chain-scenario-go v1.5.1-0.20250520075713-734e46b4c66d/go.mod h1:/bgycTrJGk6n6VlgW+onXRQKiCM3GFM+INENvj/PRgU= +github.com/multiversx/mx-chain-storage-go v1.0.20-0.20250520075958-65fd4c7bcaae h1:wgiIjyoynLQPs8QkEOXjAU8AoZlLA5w4rWPykLS2BSo= +github.com/multiversx/mx-chain-storage-go v1.0.20-0.20250520075958-65fd4c7bcaae/go.mod h1:uiXDLvpznajMubl+OBhODo6jmtwo8kyUF9iujEhOIgI= +github.com/multiversx/mx-chain-vm-common-go v1.5.17-0.20250520075408-c94bee9ee163 h1:I6WEqu3ysY41nRV7mUvdCsKyuBZlHyKngIjW4ncEcLI= +github.com/multiversx/mx-chain-vm-common-go v1.5.17-0.20250520075408-c94bee9ee163/go.mod h1:HlpJgCTYVvHE1nrEJLIsR/AJx0gqzg3m+qdJwf7jOjU= +github.com/multiversx/mx-chain-vm-go v1.5.41-0.20250520080530-2838146363b4 h1:FTgwsqn/PXEooGfO4zjEn+oH014A6agOXjeMDllneL8= +github.com/multiversx/mx-chain-vm-go v1.5.41-0.20250520080530-2838146363b4/go.mod h1:vfFEw1qAmR1mUU8/p2EKu2woKS+o9W8wSqi3muEqRds= +github.com/multiversx/mx-chain-vm-v1_2-go v1.2.69-0.20250520080927-410c413d962f h1:Sg1SZWm90IeliVPce3w0CtLjr+a+mcWAFVHW2VGR0nA= +github.com/multiversx/mx-chain-vm-v1_2-go v1.2.69-0.20250520080927-410c413d962f/go.mod h1:aeGXPTVkUDsPcHwuSer2VXEnMow7iofEvuDGPd47Cj8= +github.com/multiversx/mx-chain-vm-v1_3-go v1.3.70-0.20250520081414-edf8b75e054d h1:DFYWypkQs7BYepsLPz/IIN8cGsEf4+fWM9L0a6mnfKU= +github.com/multiversx/mx-chain-vm-v1_3-go v1.3.70-0.20250520081414-edf8b75e054d/go.mod h1:SQbB1KY4qt2HngdqcBvu9wV0pERXKP5eP+rRxciyL84= +github.com/multiversx/mx-chain-vm-v1_4-go v1.4.99-0.20250520081749-516b5ae0e49c h1:mazIPF6FgwBImN8YNorP3kjG3nygbDXbNkuBf5ILUvM= +github.com/multiversx/mx-chain-vm-v1_4-go v1.4.99-0.20250520081749-516b5ae0e49c/go.mod h1:Fi6zq++lc9cFhLVKULa6HVVD2P4Ya3GD1Lua60u/rpY= github.com/multiversx/mx-components-big-int v1.0.1-0.20250218162530-b4e4d7442408 h1:FbpVQJg14ry25DiBBZIvBoKiSrTHWkUSlYMOOW/iQJQ= github.com/multiversx/mx-components-big-int v1.0.1-0.20250218162530-b4e4d7442408/go.mod h1:kcWw7hDe6cSz1wcBAqj/6sFH6ouSPsNeH9P7XlpZRcw= github.com/multiversx/protobuf v1.3.2 h1:RaNkxvGTGbA0lMcnHAN24qE1G1i+Xs5yHA6MDvQ4mSM= @@ -436,7 +425,6 @@ github.com/nsf/termbox-go v0.0.0-20190121233118-02980233997d/go.mod h1:IuKpRQcYE 
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/nxadm/tail v1.4.11 h1:8feyoE3OzPrcshW5/MJ4sGESc5cqmGkGCWlco4l0bqY= -github.com/nxadm/tail v1.4.11/go.mod h1:OTaG3NK980DZzxbRq6lEuzgU+mug70nY11sMd4JXXHc= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= @@ -454,7 +442,6 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.17.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/onsi/gomega v1.19.0/go.mod h1:LY+I3pBVzYsTBU1AnDwOSxaYi9WoWiqgwooUqq9yPro= github.com/onsi/gomega v1.34.2 h1:pNCwDkzrsv7MS9kpaQvVb1aVLahQXyJ/Tv5oAZMI3i8= -github.com/onsi/gomega v1.34.2/go.mod h1:v1xfxRgk0KIsG+QOdm7p8UosrOzPYRo60fd3B/1Dukc= github.com/opencontainers/runtime-spec v1.0.2/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-spec v1.2.0 h1:z97+pHb3uELt/yiAWD691HNHQIF07bE7dzrbT927iTk= github.com/opencontainers/runtime-spec v1.2.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= @@ -543,7 +530,6 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= -github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= @@ -664,7 +650,6 @@ go.uber.org/fx v1.23.0 h1:lIr/gYWQGfTwGcSXWXu4vP5Ws6iqnNEIY+F/aFzCKTg= go.uber.org/fx v1.23.0/go.mod h1:o/D9n+2mLP6v1EG+qsdT1O8wKopYAsqZasju97SDFCU= go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= -go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= @@ -829,7 +814,6 @@ golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time 
v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/heartbeat/monitor/monitor_test.go b/heartbeat/monitor/monitor_test.go index 83ae428fbee..02524882220 100644 --- a/heartbeat/monitor/monitor_test.go +++ b/heartbeat/monitor/monitor_test.go @@ -9,19 +9,21 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/heartbeat/data" "github.com/multiversx/mx-chain-go/heartbeat/mock" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" - "github.com/stretchr/testify/assert" ) func createMockHeartbeatV2MonitorArgs() ArgHeartbeatV2Monitor { return ArgHeartbeatV2Monitor{ - Cache: testscommon.NewCacherMock(), + Cache: cache.NewCacherMock(), PubKeyConverter: &testscommon.PubkeyConverterMock{}, Marshaller: &marshallerMock.MarshalizerMock{}, MaxDurationPeerUnresponsive: time.Second * 3, diff --git a/heartbeat/processor/peerAuthenticationRequestsProcessor_test.go b/heartbeat/processor/peerAuthenticationRequestsProcessor_test.go index 39e21d9eb80..958ee50879b 100644 --- a/heartbeat/processor/peerAuthenticationRequestsProcessor_test.go +++ b/heartbeat/processor/peerAuthenticationRequestsProcessor_test.go @@ -14,18 +14,20 @@ import ( mxAtomic "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/random" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockArgPeerAuthenticationRequestsProcessor() ArgPeerAuthenticationRequestsProcessor { return ArgPeerAuthenticationRequestsProcessor{ RequestHandler: &testscommon.RequestHandlerStub{}, NodesCoordinator: &shardingMocks.NodesCoordinatorStub{}, - PeerAuthenticationPool: &testscommon.CacherMock{}, + PeerAuthenticationPool: &cache.CacherMock{}, ShardId: 0, Epoch: 0, MinPeersThreshold: 0.8, @@ -200,7 +202,7 @@ func TestPeerAuthenticationRequestsProcessor_startRequestingMessages(t *testing. 
}, } - args.PeerAuthenticationPool = &testscommon.CacherStub{ + args.PeerAuthenticationPool = &cache.CacherStub{ KeysCalled: func() [][]byte { return providedEligibleKeysMap[0] }, @@ -236,7 +238,7 @@ func TestPeerAuthenticationRequestsProcessor_isThresholdReached(t *testing.T) { args := createMockArgPeerAuthenticationRequestsProcessor() args.MinPeersThreshold = 0.6 counter := uint32(0) - args.PeerAuthenticationPool = &testscommon.CacherStub{ + args.PeerAuthenticationPool = &cache.CacherStub{ KeysCalled: func() [][]byte { var keys = make([][]byte, 0) switch atomic.LoadUint32(&counter) { @@ -323,7 +325,7 @@ func TestPeerAuthenticationRequestsProcessor_goRoutineIsWorkingAndCloseShouldSto }, } keysCalled := &mxAtomic.Flag{} - args.PeerAuthenticationPool = &testscommon.CacherStub{ + args.PeerAuthenticationPool = &cache.CacherStub{ KeysCalled: func() [][]byte { keysCalled.SetValue(true) return make([][]byte, 0) diff --git a/heartbeat/status/metricsUpdater_test.go b/heartbeat/status/metricsUpdater_test.go index 645f4edb0dd..c9cfd4e16df 100644 --- a/heartbeat/status/metricsUpdater_test.go +++ b/heartbeat/status/metricsUpdater_test.go @@ -8,18 +8,19 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/heartbeat/data" "github.com/multiversx/mx-chain-go/heartbeat/mock" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) func createMockArgsMetricsUpdater() ArgsMetricsUpdater { return ArgsMetricsUpdater{ - PeerAuthenticationCacher: testscommon.NewCacherMock(), + PeerAuthenticationCacher: cache.NewCacherMock(), HeartbeatMonitor: &mock.HeartbeatMonitorStub{}, HeartbeatSenderInfoProvider: &mock.HeartbeatSenderInfoProviderStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, diff --git a/integrationTests/benchmarks/loadFromTrie_test.go b/integrationTests/benchmarks/loadFromTrie_test.go index 866368e7cf4..8b2d2736b1a 100644 --- a/integrationTests/benchmarks/loadFromTrie_test.go +++ b/integrationTests/benchmarks/loadFromTrie_test.go @@ -9,8 +9,8 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing/blake2b" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" - disabledStatistics "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/common/holders" + disabledStatistics "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/database" diff --git a/integrationTests/chainSimulator/interface.go b/integrationTests/chainSimulator/interface.go index 7aba83c5103..9be64756c27 100644 --- a/integrationTests/chainSimulator/interface.go +++ b/integrationTests/chainSimulator/interface.go @@ -3,12 +3,11 @@ 
package chainSimulator import ( "math/big" - "github.com/multiversx/mx-chain-go/node/chainSimulator/dtos" - "github.com/multiversx/mx-chain-go/node/chainSimulator/process" - "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-core-go/data/transaction" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/node/chainSimulator/dtos" + "github.com/multiversx/mx-chain-go/node/chainSimulator/process" ) // ChainSimulator defines the operations for an entity that can simulate operations of a chain diff --git a/integrationTests/chainSimulator/mempool/mempool_test.go b/integrationTests/chainSimulator/mempool/mempool_test.go index 704be8e40fb..af2518e841e 100644 --- a/integrationTests/chainSimulator/mempool/mempool_test.go +++ b/integrationTests/chainSimulator/mempool/mempool_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/node/chainSimulator/configs" "github.com/multiversx/mx-chain-go/storage" - "github.com/stretchr/testify/require" ) func TestMempoolWithChainSimulator_Selection(t *testing.T) { @@ -324,7 +325,8 @@ func TestMempoolWithChainSimulator_Eviction(t *testing.T) { Signature: []byte("signature"), }) - time.Sleep(2 * time.Second) + // Allow the eviction to complete (even if it's quite fast). + time.Sleep(3 * time.Second) expectedNumTransactionsInPool := 300_000 + 1 + 1 - int(storage.TxPoolSourceMeNumItemsToPreemptivelyEvict) require.Equal(t, expectedNumTransactionsInPool, getNumTransactionsInPool(simulator, shard)) diff --git a/integrationTests/chainSimulator/mempool/testutils_test.go b/integrationTests/chainSimulator/mempool/testutils_test.go index 3d4a0afd5f7..a86494f9fa8 100644 --- a/integrationTests/chainSimulator/mempool/testutils_test.go +++ b/integrationTests/chainSimulator/mempool/testutils_test.go @@ -24,8 +24,8 @@ import ( var ( oneEGLD = big.NewInt(1000000000000000000) oneQuarterOfEGLD = big.NewInt(250000000000000000) - durationWaitAfterSendMany = 1500 * time.Millisecond - durationWaitAfterSendSome = 50 * time.Millisecond + durationWaitAfterSendMany = 3000 * time.Millisecond + durationWaitAfterSendSome = 300 * time.Millisecond ) func startChainSimulator(t *testing.T, alterConfigsFunction func(cfg *config.Configs)) testsChainSimulator.ChainSimulator { diff --git a/integrationTests/chainSimulator/relayedTx/relayedTx_test.go b/integrationTests/chainSimulator/relayedTx/relayedTx_test.go index b10d9a783c3..ca994de8b3e 100644 --- a/integrationTests/chainSimulator/relayedTx/relayedTx_test.go +++ b/integrationTests/chainSimulator/relayedTx/relayedTx_test.go @@ -13,6 +13,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" apiData "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-core-go/data/transaction" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" testsChainSimulator "github.com/multiversx/mx-chain-go/integrationTests/chainSimulator" @@ 
-25,8 +28,6 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/vm" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/require" ) const ( @@ -56,10 +57,10 @@ func TestRelayedV3WithChainSimulator(t *testing.T) { t.Skip("this is not a short test") } + t.Run("successful intra shard guarded move balance", testRelayedV3MoveBalance(0, 0, false, true)) t.Run("sender == relayer move balance should consume fee", testRelayedV3RelayedBySenderMoveBalance()) t.Run("receiver == relayer move balance should consume fee", testRelayedV3RelayedByReceiverMoveBalance()) t.Run("successful intra shard move balance", testRelayedV3MoveBalance(0, 0, false, false)) - t.Run("successful intra shard guarded move balance", testRelayedV3MoveBalance(0, 0, false, true)) t.Run("successful intra shard move balance with extra gas", testRelayedV3MoveBalance(0, 0, true, false)) t.Run("successful cross shard move balance", testRelayedV3MoveBalance(0, 1, false, false)) t.Run("successful cross shard guarded move balance", testRelayedV3MoveBalance(0, 1, false, true)) @@ -104,7 +105,6 @@ func testRelayedV3MoveBalance( guardedTx bool, ) func(t *testing.T) { return func(t *testing.T) { - providedActivationEpoch := uint32(1) alterConfigsFunc := func(cfg *config.Configs) { cfg.EpochConfig.EnableEpochs.FixRelayedBaseCostEnableEpoch = providedActivationEpoch diff --git a/integrationTests/chainSimulator/rewards/rewards_test.go b/integrationTests/chainSimulator/rewards/rewards_test.go new file mode 100644 index 00000000000..f298e9d9889 --- /dev/null +++ b/integrationTests/chainSimulator/rewards/rewards_test.go @@ -0,0 +1,389 @@ +package rewards + +import ( + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "math/big" + "os" + "path" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + apiCore "github.com/multiversx/mx-chain-core-go/data/api" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/multiversx/mx-chain-go/common" + csUtils "github.com/multiversx/mx-chain-go/integrationTests/chainSimulator" + "github.com/multiversx/mx-chain-go/node/chainSimulator" + "github.com/multiversx/mx-chain-go/node/chainSimulator/components/api" + "github.com/multiversx/mx-chain-go/node/chainSimulator/configs" + "github.com/multiversx/mx-chain-go/node/chainSimulator/dtos" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/stretchr/testify/require" +) + +const ( + defaultPathToInitialConfig = "../../../cmd/node/config/" + + moveBalanceGasLimit = 50_000 + gasPrice = 1_000_000_000 +) + +func TestRewardsAfterAndromedaWithTxs(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + startTime := time.Now().Unix() + roundDurationInMillis := uint64(6000) + roundsPerEpoch := core.OptionalUint64{ + HasValue: true, + Value: 200, + } + + numOfShards := uint32(3) + + tempDir := t.TempDir() + cs, err := chainSimulator.NewChainSimulator(chainSimulator.ArgsChainSimulator{ + BypassTxSignatureCheck: true, + TempDir: tempDir, + PathToInitialConfig: defaultPathToInitialConfig, + NumOfShards: numOfShards, + GenesisTimestamp: 
startTime, + RoundDurationInMillis: roundDurationInMillis, + RoundsPerEpoch: roundsPerEpoch, + ApiInterface: api.NewNoApiInterface(), + MinNodesPerShard: 3, + MetaChainMinNodes: 3, + }) + require.Nil(t, err) + require.NotNil(t, cs) + defer cs.Close() + + targetEpoch := 9 + for i := 0; i < targetEpoch; i++ { + err = cs.ForceChangeOfEpoch() + require.Nil(t, err) + } + + err = cs.GenerateBlocks(1) + require.Nil(t, err) + + targetShardID := uint32(0) + numTxs := 10_000 + txs := generateMoveBalance(t, cs, numTxs, targetShardID, targetShardID) + + results, err := cs.SendTxsAndGenerateBlocksTilAreExecuted(txs, 10) + require.Nil(t, err) + + blockWithTxsHash := results[0].BlockHash + blocksWithTxs, err := cs.GetNodeHandler(targetShardID).GetFacadeHandler().GetBlockByHash(blockWithTxsHash, apiCore.BlockQueryOptions{}) + require.Nil(t, err) + + prevRandSeed, _ := hex.DecodeString(blocksWithTxs.PrevRandSeed) + leader, _, err := cs.GetNodeHandler(targetShardID).GetProcessComponents().NodesCoordinator().ComputeConsensusGroup(prevRandSeed, blocksWithTxs.Round, 0, blocksWithTxs.Epoch) + require.Nil(t, err) + + nodesSetupFile := path.Join(tempDir, "config", "nodesSetup.json") + validators, err := readValidatorsAndOwners(nodesSetupFile) + require.Nil(t, err) + + err = cs.GenerateBlocks(210) + require.Nil(t, err) + + metaBlock := getLastStartOfEpochBlock(t, cs, core.MetachainShardId) + require.NotNil(t, metaBlock) + + leaderEncoded, _ := cs.GetNodeHandler(0).GetCoreComponents().ValidatorPubKeyConverter().Encode(leader.PubKey()) + leaderOwnerBlockWithTxs := validators[leaderEncoded] + + var anotherOwner string + found := false + for _, address := range validators { + if address != leaderOwnerBlockWithTxs { + anotherOwner = address + found = true + } + } + require.True(t, found) + + rewardTxForLeader := getRewardTxForAddress(metaBlock, leaderOwnerBlockWithTxs) + require.NotNil(t, rewardTxForLeader) + anotherRewardTx := getRewardTxForAddress(metaBlock, anotherOwner) + require.NotNil(t, anotherRewardTx) + + rewardTxValueLeaderWithTxs, _ := big.NewInt(0).SetString(rewardTxForLeader.Value, 10) + rewardTxValueAnotherOwner, _ := big.NewInt(0).SetString(anotherRewardTx.Value, 10) + + coordinator := cs.GetNodeHandler(0).GetProcessComponents().NodesCoordinator() + + rewardsPerShard, err := computeRewardsForShards(metaBlock, coordinator, validators) + require.Nil(t, err) + + // diff should be equal with 0.1 * moveBalanceCost * num transactions + // diff = 0.1 * move balance gas limit * gas price * num transactions + diff := big.NewInt(0).Mul(big.NewInt(moveBalanceGasLimit*0.1), big.NewInt(gasPrice)) + diff.Mul(diff, big.NewInt(int64(numTxs))) + + // check reward tx value + require.Equal(t, rewardTxValueLeaderWithTxs, big.NewInt(0).Add(rewardTxValueAnotherOwner, diff)) + + // rewards for target shard should be rewards for another shard + diff + require.Equal(t, rewardsPerShard[targetShardID], big.NewInt(0).Add(rewardsPerShard[core.MetachainShardId], diff)) +} + +func getRewardTxForAddress(block *apiCore.Block, address string) *transaction.ApiTransactionResult { + for _, mb := range block.MiniBlocks { + for _, tx := range mb.Transactions { + if tx.Receiver == address { + return tx + } + } + } + + return nil +} + +func generateMoveBalance(t *testing.T, cs chainSimulator.ChainSimulator, numTxs int, senderShardID, receiverShardID uint32) []*transaction.Transaction { + numSenders := (numTxs + common.MaxTxNonceDeltaAllowed - 1) / common.MaxTxNonceDeltaAllowed + tenEGLD := big.NewInt(0).Mul(csUtils.OneEGLD, big.NewInt(10)) + + 
senders := make([]dtos.WalletAddress, 0, numSenders) + for i := 0; i < numSenders; i++ { + + sender, err := cs.GenerateAndMintWalletAddress(senderShardID, tenEGLD) + require.Nil(t, err) + + senders = append(senders, sender) + } + + err := cs.GenerateBlocks(1) + require.Nil(t, err) + + txs := make([]*transaction.Transaction, 0, numTxs) + for i := 0; i < numSenders-1; i++ { + txs = append(txs, generateMoveBalanceTxs(t, cs, senders[i], common.MaxTxNonceDeltaAllowed, receiverShardID)...) + } + + lastBatchSize := numTxs % common.MaxTxNonceDeltaAllowed + if lastBatchSize == 0 { + lastBatchSize = common.MaxTxNonceDeltaAllowed + } + + txs = append(txs, generateMoveBalanceTxs(t, cs, senders[len(senders)-1], lastBatchSize, receiverShardID)...) + + return txs +} + +func generateMoveBalanceTxs(t *testing.T, cs chainSimulator.ChainSimulator, sender dtos.WalletAddress, numTxs int, receiverShardID uint32) []*transaction.Transaction { + senderShardID := cs.GetNodeHandler(core.MetachainShardId).GetShardCoordinator().ComputeId(sender.Bytes) + + res, _, err := cs.GetNodeHandler(senderShardID).GetFacadeHandler().GetAccount(sender.Bech32, apiCore.AccountQueryOptions{}) + require.Nil(t, err) + + txs := make([]*transaction.Transaction, numTxs) + initialNonce := res.Nonce + for i := 0; i < numTxs; i++ { + rcv := cs.GenerateAddressInShard(receiverShardID) + + txs[i] = &transaction.Transaction{ + Nonce: initialNonce + uint64(i), + Value: big.NewInt(1), + RcvAddr: rcv.Bytes, + SndAddr: sender.Bytes, + GasPrice: gasPrice, + GasLimit: moveBalanceGasLimit, + ChainID: []byte(configs.ChainID), + Version: 2, + Signature: []byte("sig"), + } + } + + return txs +} + +func TestRewardsTxsAfterAndromeda(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + startTime := time.Now().Unix() + roundDurationInMillis := uint64(6000) + roundsPerEpoch := core.OptionalUint64{ + HasValue: true, + Value: 200, + } + + numOfShards := uint32(3) + + tempDir := t.TempDir() + cs, err := chainSimulator.NewChainSimulator(chainSimulator.ArgsChainSimulator{ + BypassTxSignatureCheck: true, + TempDir: tempDir, + PathToInitialConfig: defaultPathToInitialConfig, + NumOfShards: numOfShards, + GenesisTimestamp: startTime, + RoundDurationInMillis: roundDurationInMillis, + RoundsPerEpoch: roundsPerEpoch, + ApiInterface: api.NewNoApiInterface(), + MinNodesPerShard: 3, + MetaChainMinNodes: 3, + }) + require.Nil(t, err) + require.NotNil(t, cs) + defer cs.Close() + + targetEpoch := 9 + for i := 0; i < targetEpoch; i++ { + err = cs.ForceChangeOfEpoch() + require.Nil(t, err) + } + + err = cs.GenerateBlocks(210) + require.Nil(t, err) + + nodesSetupFile := path.Join(tempDir, "config", "nodesSetup.json") + validators, err := readValidatorsAndOwners(nodesSetupFile) + require.Nil(t, err) + + metaBlock := getLastStartOfEpochBlock(t, cs, core.MetachainShardId) + require.NotNil(t, metaBlock) + + coordinator := cs.GetNodeHandler(0).GetProcessComponents().NodesCoordinator() + + rewardsPerShard, err := computeRewardsForShards(metaBlock, coordinator, validators) + require.Nil(t, err) + + for shardID, reward := range rewardsPerShard { + fmt.Printf("rewards on shard %d: %s\n", shardID, reward.String()) + } + + require.True(t, allValuesEqual(rewardsPerShard)) +} + +func getLastStartOfEpochBlock(t *testing.T, cs chainSimulator.ChainSimulator, shardID uint32) *apiCore.Block { + metachainHandler := cs.GetNodeHandler(shardID).GetFacadeHandler() + + networkStatus, err := metachainHandler.StatusMetrics().NetworkMetrics() + require.Nil(t, err) + + 
epochStartBlocKNonce, ok := networkStatus[common.MetricNonceAtEpochStart].(uint64) + require.True(t, ok) + + metaBlock, err := metachainHandler.GetBlockByNonce(epochStartBlocKNonce, apiCore.BlockQueryOptions{ + WithTransactions: true, + }) + require.Nil(t, err) + + return metaBlock +} + +func computeRewardsForShards( + metaBlock *apiCore.Block, + coordinator nodesCoordinator.NodesCoordinator, + validators map[string]string, +) (map[uint32]*big.Int, error) { + shards := []uint32{0, 1, 2, core.MetachainShardId} + rewardsPerShard := make(map[uint32]*big.Int) + + for _, shardID := range shards { + rewardsPerShard[shardID] = big.NewInt(0) // Initialize reward entry + err := computeRewardsForShard(metaBlock, coordinator, validators, shardID, rewardsPerShard) + if err != nil { + return nil, err + } + } + + return rewardsPerShard, nil +} + +func computeRewardsForShard(metaBlock *apiCore.Block, + coordinator nodesCoordinator.NodesCoordinator, + validators map[string]string, + shardID uint32, + rewardsPerShard map[uint32]*big.Int, +) error { + validatorsPerShard, _ := coordinator.GetAllEligibleValidatorsPublicKeysForShard(8, shardID) + + for _, validator := range validatorsPerShard { + owner, exists := validators[hex.EncodeToString([]byte(validator))] + if !exists { + continue + } + err := accumulateShardRewards(metaBlock, shardID, owner, rewardsPerShard) + if err != nil { + return err + } + } + + return nil +} + +func accumulateShardRewards(metaBlock *apiCore.Block, shardID uint32, owner string, rewardsPerShard map[uint32]*big.Int) error { + var firstValue *big.Int + for _, mb := range metaBlock.MiniBlocks { + if mb.Type != block.RewardsBlock.String() { + continue + } + + for _, tx := range mb.Transactions { + if tx.Receiver != owner { + continue + } + + valueBig, _ := new(big.Int).SetString(tx.Value, 10) + if firstValue == nil { + firstValue = valueBig + } + if valueBig.Cmp(firstValue) != 0 { + return errors.New("different values in rewards transactions") + } + + rewardsPerShard[shardID].Add(rewardsPerShard[shardID], valueBig) + } + } + + return nil +} + +func readValidatorsAndOwners(filePath string) (map[string]string, error) { + file, err := os.ReadFile(filePath) + if err != nil { + return nil, err + } + + var nodesSetup struct { + InitialNodes []struct { + PubKey string `json:"pubkey"` + Address string `json:"address"` + } `json:"initialNodes"` + } + + err = json.Unmarshal(file, &nodesSetup) + if err != nil { + return nil, err + } + + validators := make(map[string]string) + for _, node := range nodesSetup.InitialNodes { + validators[node.PubKey] = node.Address + } + + return validators, nil +} + +func allValuesEqual(m map[uint32]*big.Int) bool { + if len(m) == 0 { + return true + } + expectedValue := m[0] + for _, v := range m { + if expectedValue.Cmp(v) != 0 { + return false + } + } + return true +} diff --git a/integrationTests/chainSimulator/staking/common.go b/integrationTests/chainSimulator/staking/common.go index 4de97df500e..e9e8bee3643 100644 --- a/integrationTests/chainSimulator/staking/common.go +++ b/integrationTests/chainSimulator/staking/common.go @@ -5,13 +5,12 @@ import ( "math/big" "testing" + "github.com/multiversx/mx-chain-core-go/core" chainSimulatorIntegrationTests "github.com/multiversx/mx-chain-go/integrationTests/chainSimulator" "github.com/multiversx/mx-chain-go/node/chainSimulator/dtos" chainSimulatorProcess "github.com/multiversx/mx-chain-go/node/chainSimulator/process" 
"github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/vm" - - "github.com/multiversx/mx-chain-core-go/core" "github.com/stretchr/testify/require" ) diff --git a/integrationTests/chainSimulator/staking/jail/jail_test.go b/integrationTests/chainSimulator/staking/jail/jail_test.go index bb449da993f..9f431d1f924 100644 --- a/integrationTests/chainSimulator/staking/jail/jail_test.go +++ b/integrationTests/chainSimulator/staking/jail/jail_test.go @@ -22,10 +22,9 @@ import ( ) const ( - stakingV4JailUnJailStep1EnableEpoch = 5 + stakingV4JailUnJailStep1EnableEpoch = 9 defaultPathToInitialConfig = "../../../../cmd/node/config/" - - epochWhenNodeIsJailed = 4 + epochWhenNodeIsJailed = 8 ) // Test description @@ -40,19 +39,19 @@ func TestChainSimulator_ValidatorJailUnJail(t *testing.T) { } t.Run("staking ph 4 is not active", func(t *testing.T) { - testChainSimulatorJailAndUnJail(t, 4, "new") + testChainSimulatorJailAndUnJail(t, 8, "new") }) t.Run("staking ph 4 step 1 active", func(t *testing.T) { - testChainSimulatorJailAndUnJail(t, 5, "auction") + testChainSimulatorJailAndUnJail(t, 9, "auction") }) t.Run("staking ph 4 step 2 active", func(t *testing.T) { - testChainSimulatorJailAndUnJail(t, 6, "auction") + testChainSimulatorJailAndUnJail(t, 10, "auction") }) t.Run("staking ph 4 step 3 active", func(t *testing.T) { - testChainSimulatorJailAndUnJail(t, 7, "auction") + testChainSimulatorJailAndUnJail(t, 11, "auction") }) } @@ -75,10 +74,11 @@ func testChainSimulatorJailAndUnJail(t *testing.T, targetEpoch int32, nodeStatus RoundDurationInMillis: roundDurationInMillis, RoundsPerEpoch: roundsPerEpoch, ApiInterface: api.NewNoApiInterface(), - MinNodesPerShard: 2, - MetaChainMinNodes: 2, + MinNodesPerShard: 4, + MetaChainMinNodes: 4, AlterConfigsFunction: func(cfg *config.Configs) { configs.SetStakingV4ActivationEpochs(cfg, stakingV4JailUnJailStep1EnableEpoch) + cfg.EpochConfig.EnableEpochs.AndromedaEnableEpoch = 100 newNumNodes := cfg.SystemSCConfig.StakingSystemSCConfig.MaxNumberOfNodesForStake + 8 // 8 nodes until new nodes will be placed on queue configs.SetMaxNumberOfNodesInConfigs(cfg, uint32(newNumNodes), 0, numOfShards) configs.SetQuickJailRatingConfig(cfg) @@ -178,8 +178,8 @@ func TestChainSimulator_FromQueueToAuctionList(t *testing.T) { RoundDurationInMillis: roundDurationInMillis, RoundsPerEpoch: roundsPerEpoch, ApiInterface: api.NewNoApiInterface(), - MinNodesPerShard: 3, - MetaChainMinNodes: 3, + MinNodesPerShard: 4, + MetaChainMinNodes: 4, AlterConfigsFunction: func(cfg *config.Configs) { configs.SetStakingV4ActivationEpochs(cfg, stakingV4JailUnJailStep1EnableEpoch) configs.SetQuickJailRatingConfig(cfg) @@ -255,3 +255,65 @@ func TestChainSimulator_FromQueueToAuctionList(t *testing.T) { staking.CheckValidatorStatus(t, cs, blsKeys[0], string(common.InactiveList)) } + +func TestJailNodes(t *testing.T) { + startTime := time.Now().Unix() + roundDurationInMillis := uint64(6000) + roundsPerEpoch := core.OptionalUint64{ + HasValue: true, + Value: 20, + } + + numOfShards := uint32(3) + + cs, err := chainSimulator.NewChainSimulator(chainSimulator.ArgsChainSimulator{ + BypassTxSignatureCheck: true, + TempDir: t.TempDir(), + PathToInitialConfig: defaultPathToInitialConfig, + NumOfShards: numOfShards, + GenesisTimestamp: startTime, + RoundDurationInMillis: roundDurationInMillis, + RoundsPerEpoch: roundsPerEpoch, + ApiInterface: api.NewNoApiInterface(), + MinNodesPerShard: 4, + MetaChainMinNodes: 4, + 
NumNodesWaitingListMeta: 1, + NumNodesWaitingListShard: 1, + AlterConfigsFunction: func(cfg *config.Configs) { + configs.SetQuickJailRatingConfig(cfg) + newNumNodes := cfg.SystemSCConfig.StakingSystemSCConfig.MaxNumberOfNodesForStake + 1 + configs.SetMaxNumberOfNodesInConfigs(cfg, uint32(newNumNodes), 0, numOfShards) + }, + }) + require.Nil(t, err) + require.NotNil(t, cs) + defer cs.Close() + + err = cs.GenerateBlocks(30) + require.Nil(t, err) + + _, blsKeys, err := chainSimulator.GenerateBlsPrivateKeys(1) + require.Nil(t, err) + + mintValue := big.NewInt(0).Mul(chainSimulatorIntegrationTests.OneEGLD, big.NewInt(6000)) + walletAddress, err := cs.GenerateAndMintWalletAddress(core.AllShardId, mintValue) + require.Nil(t, err) + + for i := 0; i < 10; i++ { + err = cs.ForceChangeOfEpoch() + require.Nil(t, err) + } + + txDataField := fmt.Sprintf("stake@01@%s@%s", blsKeys[0], staking.MockBLSSignature) + txStake := chainSimulatorIntegrationTests.GenerateTransaction(walletAddress.Bytes, 0, vm.ValidatorSCAddress, chainSimulatorIntegrationTests.MinimumStakeValue, txDataField, staking.GasLimitForStakeOperation) + stakeTx, err := cs.SendTxAndGenerateBlockTilTxIsExecuted(txStake, staking.MaxNumOfBlockToGenerateWhenExecutingTx) + require.Nil(t, err) + require.NotNil(t, stakeTx) + + err = cs.GenerateBlocks(200) + require.Nil(t, err) + + decodedBLSKey0, _ := hex.DecodeString(blsKeys[0]) + status := staking.GetBLSKeyStatus(t, cs.GetNodeHandler(core.MetachainShardId), decodedBLSKey0) + require.Equal(t, "jailed", status) +} diff --git a/integrationTests/chainSimulator/staking/stakingProvider/delegation_test.go b/integrationTests/chainSimulator/staking/stakingProvider/delegation_test.go index bcc63c71d28..c9b20b2e218 100644 --- a/integrationTests/chainSimulator/staking/stakingProvider/delegation_test.go +++ b/integrationTests/chainSimulator/staking/stakingProvider/delegation_test.go @@ -94,6 +94,8 @@ func TestChainSimulator_MakeNewContractFromValidatorData(t *testing.T) { cfg.EpochConfig.EnableEpochs.StakingV4Step3EnableEpoch = 102 cfg.EpochConfig.EnableEpochs.MaxNodesChangeEnableEpoch[2].EpochEnable = 102 + + cfg.EpochConfig.EnableEpochs.AndromedaEnableEpoch = 1 }, }) require.Nil(t, err) @@ -139,6 +141,8 @@ func TestChainSimulator_MakeNewContractFromValidatorData(t *testing.T) { cfg.EpochConfig.EnableEpochs.StakingV4Step3EnableEpoch = 102 cfg.EpochConfig.EnableEpochs.MaxNodesChangeEnableEpoch[2].EpochEnable = 102 + + cfg.EpochConfig.EnableEpochs.AndromedaEnableEpoch = 1 }, }) require.Nil(t, err) diff --git a/integrationTests/chainSimulator/vm/egldMultiTransfer_test.go b/integrationTests/chainSimulator/vm/egldMultiTransfer_test.go index 5e9d641f90a..f664eff7539 100644 --- a/integrationTests/chainSimulator/vm/egldMultiTransfer_test.go +++ b/integrationTests/chainSimulator/vm/egldMultiTransfer_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/node/chainSimulator" "github.com/multiversx/mx-chain-go/node/chainSimulator/components/api" "github.com/multiversx/mx-chain-go/node/chainSimulator/configs" + "github.com/multiversx/mx-chain-go/node/chainSimulator/dtos" "github.com/multiversx/mx-chain-go/vm" "github.com/stretchr/testify/require" ) @@ -761,3 +762,88 @@ func TestChainSimulator_IssueToken_EGLDTicker(t *testing.T) { require.Equal(t, "success", txResult.Status.String()) } + +func TestScCallTransferValueESDT(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + 
roundDurationInMillis := uint64(6000) + roundsPerEpochOpt := core.OptionalUint64{ + HasValue: true, + Value: 20, + } + + cs, err := chainSimulator.NewChainSimulator(chainSimulator.ArgsChainSimulator{ + BypassTxSignatureCheck: true, + TempDir: t.TempDir(), + PathToInitialConfig: defaultPathToInitialConfig, + NumOfShards: 3, + GenesisTimestamp: time.Now().Unix(), + RoundDurationInMillis: roundDurationInMillis, + RoundsPerEpoch: roundsPerEpochOpt, + ApiInterface: api.NewNoApiInterface(), + MinNodesPerShard: 3, + MetaChainMinNodes: 3, + NumNodesWaitingListMeta: 3, + NumNodesWaitingListShard: 3, + + InitialEpoch: 1700, + InitialNonce: 1700, + InitialRound: 1700, + }) + require.NoError(t, err) + require.NotNil(t, cs) + + err = cs.GenerateBlocks(1) + require.NoError(t, err) + + nonce := uint64(0) + err = cs.SetStateMultiple([]*dtos.AddressState{ + { + Address: "erd1qqqqqqqqqqqqqpgqw7vrdzlqg4f8zja8qgpw2cdqpcp5xhvrvcqs984824", + Nonce: &nonce, + Balance: "0", + Code: "0061736d01000000016b116000017f60000060027f7f017f60027f7f0060017f0060017f017f60047f7f7f7f0060037f7f7f017f60027f7e0060037f7f7f0060047f7f7f7f017f60067e7f7f7f7f7f017f60057f7f7e7f7f017f6000017e60057f7f7f7f7f0060047f7f7f7e0060057e7f7f7f7f0002bd041803656e760a6d4275666665724e6577000003656e760d6d427566666572417070656e64000203656e760d6d616e6167656443616c6c6572000403656e76126d427566666572417070656e644279746573000703656e76126d616e616765645369676e616c4572726f72000403656e76126d427566666572476574417267756d656e74000203656e76106d4275666665724765744c656e677468000503656e7619626967496e74476574556e7369676e6564417267756d656e74000303656e760f6765744e756d417267756d656e7473000003656e760b7369676e616c4572726f72000303656e761b6d616e61676564457865637574654f6e44657374436f6e74657874000b03656e760f6d4275666665725365744279746573000703656e76196d42756666657246726f6d426967496e74556e7369676e6564000203656e760e626967496e74536574496e743634000803656e76106d616e61676564534341646472657373000403656e760e636865636b4e6f5061796d656e74000103656e761b6d616e616765645472616e7366657256616c756545786563757465000c03656e761c6d616e616765644765744d756c74694553445443616c6c56616c7565000403656e76096d4275666665724571000203656e7609626967496e74416464000903656e760a6765744761734c656674000d03656e760f636c65616e52657475726e44617461000103656e76136d42756666657253746f7261676553746f7265000203656e76136d42756666657247657442797465536c696365000a031e1d05000002040509000605030e05030a06060f081000030300010101010105030100030616037f01418080080b7f00419582080b7f0041a082080b075608066d656d6f7279020004696e697400300775706772616465003107666f7277617264003208726563656976656400330863616c6c4261636b00340a5f5f646174615f656e6403010b5f5f686561705f6261736503020a84151d0f01017f10002201200010011a20010b0c01017f101a2200100220000b1901017f419082084190820828020041016b220036020020000b1101017f101a220220002001100b1a20020b1400100820004604400f0b41d6800841191009000b2b01027f2000419482082d0000220171200041ff01714622024504404194820820002001723a00000b20020b180020012002101b21012000101f360204200020013602000b080041014100101b0b1e00101f1a200220032802001021102220002002360204200020013602000b0f01017f101a22012000100c1a20010b4601017f230041106b220224002002200141187420014180fe03714108747220014108764180fe03712001411876727236020c20002002410c6a410410031a200241106a24000b8e0101037f230041106b220524000240200310240d00200220031025200410062107410021030340200320074f0d012005410036020c200420032005410c6a410410261a2002200528020c220641187420064180fe03714108747220064108764180fe0371200641187672721025200341046a21030c000b000b2000200236020420002001360200200541106a24000b070020001006450b0d001
01f1a20002001101810220b0f00200020012003200210174100470b1e00101f1a200220032802001018102220002002360204200020013602000b1b00101f1a200220031018102220002002360204200020013602000b2001017f101f22042003102a20022004102220002002360204200020013602000bff0102027f017e230041106b220324002003200142388620014280fe0383422886842001428080fc0783421886200142808080f80f834208868484200142088842808080f80f832001421888428080fc078384200142288822044280fe03832001423888848484370308200041002001428080808080808080015422002001423088a741ff01711b220220006a410020022004a741ff01711b22006a410020002001422088a741ff01711b22006a410020002001a722004118761b22026a41002002200041107641ff01711b22026a41002002200041087641ff01711b22006a200041002001501b6a2200200341086a6a410820006b100b1a200341106a24000b110020002001200220032004101a100a1a0b0a0041764200100d41760b7001037f230041106b22022400200020012802042204200128020849047f200241086a2203420037030020024200370300200128020020042002411010261a2001200441106a36020420002002290300370001200041096a200329030037000041010541000b3a0000200241106a24000bb90102017f017e2000200128000c220241187420024180fe03714108747220024108764180fe03712002411876727236020c20002001280000220241187420024180fe03714108747220024108764180fe03712002411876727236020820002001290004220342388620034280fe0383422886842003428080fc0783421886200342808080f80f834208868484200342088842808080f80f832003421888428080fc07838420034228884280fe038320034238888484843703000b0c01017f101a2200100e20000b0800100f4100101c0b3301037f100f4100101c1019210141e58108412a101b2100101f2102200010244504402001102c42a0c21e2000200210101a0b0b930b020a7f027e230041d0016b220024004101101c4100101a220110051a20011006412047044041bc80084117101b220041d58108411010031a200041d38008410310031a200041bb8108411010031a20001004000b200121034102101d450440415a10110b02404104101d0d00415841b18008410b100b1a2000415a100636029801200042daffffff0f370290010340200041b8016a20004190016a102d20002d00b8014101470d01415820002800b901220141187420014180fe03714108747220014108764180fe037120014118767272101241004c0d000b4199800841181009000b1019101821010240200341feffffff07470440200041f8006a41d181084104101e200041f0006a2000280278200028027c200110282000280274210520002802702106101f21022000415a100636028c01200042daffffff0f37028401200041c0016a210820004191016a2107034020004190016a20004184016a102d20002d0090014101460440200041b0016a200741086a290000370300200020072900003703a8012008200041a8016a102e20002903c001210a20002802cc01210920002802c80110182104101a22014200100d20012001200910132000200a423886200a4280fe038342288684200a428080fc0783421886200a42808080f80f834208868484200a42088842808080f80f83200a421888428080fc078384200a4228884280fe0383200a423888848484370294012000200441187420044180fe03714108747220044108764180fe037120044118767272360290012000200141187420014180fe03714108747220014108764180fe03712001411876727236029c01200220004190016a411010031a0c010b0b1014220a42a08d067d200a200a42a08d06561b210b0240024002400240200210064104760e020102000b200041186a41ef80084114101e20002802182104200028021c2101101f1a2001200310181022200210062103101f22072003410476ad102a20012007102220002002100636029801200041003602940120002002360290010340200041b8016a20004190016a102d20002d00b801410146044020002800c501210220002900bd01210a200120002800b901220341187420034180fe03714108747220034108764180fe0371200341187672721025200041086a20042001200a423886200a4280fe038342288684200a428080fc0783421886200a42808080f80f834208868484200a42088842808080f80f83200a421888428080fc078384200a4228884280fe0383200a423888848484102920002802082104200028020c2101101f1a2001200241187420024180fe03714108747220024108764180fe037120024118767272102110220c010b
0b200041106a200420012006200510232000280214210120002802102102200b102f102c20022001102b0c020b200b2003102c20062005102b0c010b20004198016a420037030020004200370390012002410020004190016a411010260d02200041c0016a20004190016a102e200041b0016a2201200041c8016a290300370300200020002903c001220a3703a801200041b4016a2102200a500440200041386a41928108410c101e200041306a2000280238200028023c20011027200041286a2000280230200028023420021020200041206a2000280228200028022c2006200510232000280224210120002802202102200b2003102c20022001102b0c010b200041e8006a41838108410f101e200041e0006a2000280268200028026c20011027200041d8006a20002802602000280264200a1029200041d0006a2000280258200028025c20021020200041c8006a2000280250200028025420031028200041406b2000280248200028024c2006200510232000280244210120002802402102200b102f102c20022001102b0b1015200041d0016a24000f0b4180800841191009000b419e8108411d1009000b2101017f100f4101101c4100101a2200100741cb81084106101b2000102110161a0b0300010b0ba3020200418080080b8f02726563697069656e742061646472657373206e6f7420736574756e65787065637465642045474c44207472616e7366657245474c442d303030303030617267756d656e74206465636f6465206572726f722028293a2077726f6e67206e756d626572206f6620617267756d656e74734d756c7469455344544e46545472616e73666572455344544e46545472616e73666572455344545472616e736665724d616e6167656456656320696e646578206f7574206f662072616e6765626164206172726179206c656e677468616d6f756e747465737464756d6d795f73635f61646472657373455344545472616e7366657240353535333434343332443333333533303633333436354030463432343000419082080b0438ffffff", + RootHash: "", + CodeMetadata: "BQQ=", + CodeHash: "KQsZ3JMD3ojAR5MfScgQH0o3XLUpvP1H7flxxt0qe80=", + DeveloperRewards: "4354035000000", + Owner: "erd1tkc62psh0flcj6anm6gt227gqqu7sp4xc3c3cc0fcmgk9ax6vcqs2w8h2s", + }, + { + Address: "erd1tkc62psh0flcj6anm6gt227gqqu7sp4xc3c3cc0fcmgk9ax6vcqs2w8h2s", + Nonce: &nonce, + Balance: "191060078069461323", + }, + }) + require.NoError(t, err) + + err = cs.GenerateBlocks(1) + require.NoError(t, err) + + pubKeyConverter := cs.GetNodeHandler(0).GetCoreComponents().AddressPubKeyConverter() + sndBech := "erd1tkc62psh0flcj6anm6gt227gqqu7sp4xc3c3cc0fcmgk9ax6vcqs2w8h2s" + snd, _ := pubKeyConverter.Decode(sndBech) + rcv, _ := pubKeyConverter.Decode("erd1qqqqqqqqqqqqqpgqw7vrdzlqg4f8zja8qgpw2cdqpcp5xhvrvcqs984824") + + tx := &transaction.Transaction{ + Nonce: 0, + Value: big.NewInt(0), + SndAddr: snd, + RcvAddr: rcv, + Data: 
[]byte("upgradeContract@0061736d01000000016b116000017f60000060027f7f017f60027f7f0060017f0060017f017f60047f7f7f7f0060037f7f7f017f60027f7e0060037f7f7f0060047f7f7f7f017f60067e7f7f7f7f7f017f60057f7f7e7f7f017f6000017e60057f7f7f7f7f0060047f7f7f7e0060057e7f7f7f7f0002bd041803656e760a6d4275666665724e6577000003656e760d6d427566666572417070656e64000203656e760d6d616e6167656443616c6c6572000403656e76126d427566666572417070656e644279746573000703656e76126d616e616765645369676e616c4572726f72000403656e76126d427566666572476574417267756d656e74000203656e76106d4275666665724765744c656e677468000503656e7619626967496e74476574556e7369676e6564417267756d656e74000303656e760f6765744e756d417267756d656e7473000003656e760b7369676e616c4572726f72000303656e761b6d616e61676564457865637574654f6e44657374436f6e74657874000b03656e760f6d4275666665725365744279746573000703656e76196d42756666657246726f6d426967496e74556e7369676e6564000203656e760e626967496e74536574496e743634000803656e76106d616e61676564534341646472657373000403656e760e636865636b4e6f5061796d656e74000103656e761b6d616e616765645472616e7366657256616c756545786563757465000c03656e761c6d616e616765644765744d756c74694553445443616c6c56616c7565000403656e76096d4275666665724571000203656e7609626967496e74416464000903656e760a6765744761734c656674000d03656e760f636c65616e52657475726e44617461000103656e76136d42756666657253746f7261676553746f7265000203656e76136d42756666657247657442797465536c696365000a031e1d05000002040509000605030e05030a06060f081000030300010101010105030100030616037f01418080080b7f00419582080b7f0041a082080b075608066d656d6f7279020004696e697400300775706772616465003107666f7277617264003208726563656976656400330863616c6c4261636b00340a5f5f646174615f656e6403010b5f5f686561705f6261736503020a84151d0f01017f10002201200010011a20010b0c01017f101a2200100220000b1901017f419082084190820828020041016b220036020020000b1101017f101a220220002001100b1a20020b1400100820004604400f0b41d6800841191009000b2b01027f2000419482082d0000220171200041ff01714622024504404194820820002001723a00000b20020b180020012002101b21012000101f360204200020013602000b080041014100101b0b1e00101f1a200220032802001021102220002002360204200020013602000b0f01017f101a22012000100c1a20010b4601017f230041106b220224002002200141187420014180fe03714108747220014108764180fe03712001411876727236020c20002002410c6a410410031a200241106a24000b8e0101037f230041106b220524000240200310240d00200220031025200410062107410021030340200320074f0d012005410036020c200420032005410c6a410410261a2002200528020c220641187420064180fe03714108747220064108764180fe0371200641187672721025200341046a21030c000b000b2000200236020420002001360200200541106a24000b070020001006450b0d00101f1a20002001101810220b0f00200020012003200210174100470b1e00101f1a200220032802001018102220002002360204200020013602000b1b00101f1a200220031018102220002002360204200020013602000b2001017f101f22042003102a20022004102220002002360204200020013602000bff0102027f017e230041106b220324002003200142388620014280fe0383422886842001428080fc0783421886200142808080f80f834208868484200142088842808080f80f832001421888428080fc078384200142288822044280fe03832001423888848484370308200041002001428080808080808080015422002001423088a741ff01711b220220006a410020022004a741ff01711b22006a410020002001422088a741ff01711b22006a410020002001a722004118761b22026a41002002200041107641ff01711b22026a41002002200041087641ff01711b22006a200041002001501b6a2200200341086a6a410820006b100b1a200341106a24000b110020002001200220032004101a100a1a0b0a0041764200100d41760b7001037f230041106b22022400200020012802042204200128020849047f200241086a2203420037030020024200370300200128020020042002411010261a2001200441106
a36020420002002290300370001200041096a200329030037000041010541000b3a0000200241106a24000bb90102017f017e2000200128000c220241187420024180fe03714108747220024108764180fe03712002411876727236020c20002001280000220241187420024180fe03714108747220024108764180fe03712002411876727236020820002001290004220342388620034280fe0383422886842003428080fc0783421886200342808080f80f834208868484200342088842808080f80f832003421888428080fc07838420034228884280fe038320034238888484843703000b0c01017f101a2200100e20000b0800100f4100101c0b3301037f100f4100101c1019210141e58108412a101b2100101f2102200010244504402001102c42a0c21e2000200210101a0b0b930b020a7f027e230041d0016b220024004101101c4100101a220110051a20011006412047044041bc80084117101b220041d58108411010031a200041d38008410310031a200041bb8108411010031a20001004000b200121034102101d450440415a10110b02404104101d0d00415841b18008410b100b1a2000415a100636029801200042daffffff0f370290010340200041b8016a20004190016a102d20002d00b8014101470d01415820002800b901220141187420014180fe03714108747220014108764180fe037120014118767272101241004c0d000b4199800841181009000b1019101821010240200341feffffff07470440200041f8006a41d181084104101e200041f0006a2000280278200028027c200110282000280274210520002802702106101f21022000415a100636028c01200042daffffff0f37028401200041c0016a210820004191016a2107034020004190016a20004184016a102d20002d0090014101460440200041b0016a200741086a290000370300200020072900003703a8012008200041a8016a102e20002903c001210a20002802cc01210920002802c80110182104101a22014200100d20012001200910132000200a423886200a4280fe038342288684200a428080fc0783421886200a42808080f80f834208868484200a42088842808080f80f83200a421888428080fc078384200a4228884280fe0383200a423888848484370294012000200441187420044180fe03714108747220044108764180fe037120044118767272360290012000200141187420014180fe03714108747220014108764180fe03712001411876727236029c01200220004190016a411010031a0c010b0b1014220a42a08d067d200a200a42a08d06561b210b0240024002400240200210064104760e020102000b200041186a41ef80084114101e20002802182104200028021c2101101f1a2001200310181022200210062103101f22072003410476ad102a20012007102220002002100636029801200041003602940120002002360290010340200041b8016a20004190016a102d20002d00b801410146044020002800c501210220002900bd01210a200120002800b901220341187420034180fe03714108747220034108764180fe0371200341187672721025200041086a20042001200a423886200a4280fe038342288684200a428080fc0783421886200a42808080f80f834208868484200a42088842808080f80f83200a421888428080fc078384200a4228884280fe0383200a423888848484102920002802082104200028020c2101101f1a2001200241187420024180fe03714108747220024108764180fe037120024118767272102110220c010b0b200041106a200420012006200510232000280214210120002802102102200b102f102c20022001102b0c020b200b2003102c20062005102b0c010b20004198016a420037030020004200370390012002410020004190016a411010260d02200041c0016a20004190016a102e200041b0016a2201200041c8016a290300370300200020002903c001220a3703a801200041b4016a2102200a500440200041386a41928108410c101e200041306a2000280238200028023c20011027200041286a2000280230200028023420021020200041206a2000280228200028022c2006200510232000280224210120002802202102200b2003102c20022001102b0c010b200041e8006a41838108410f101e200041e0006a2000280268200028026c20011027200041d8006a20002802602000280264200a1029200041d0006a2000280258200028025c20021020200041c8006a2000280250200028025420031028200041406b2000280248200028024c2006200510232000280244210120002802402102200b102f102c20022001102b0b1015200041d0016a24000f0b4180800841191009000b419e8108411d1009000b2101017f100f4101101c4100101a2200100741cb81084106101b2000102110161a0b0300010b0ba302020041
8080080b8f02726563697069656e742061646472657373206e6f7420736574756e65787065637465642045474c44207472616e7366657245474c442d303030303030617267756d656e74206465636f6465206572726f722028293a2077726f6e67206e756d626572206f6620617267756d656e74734d756c7469455344544e46545472616e73666572455344544e46545472616e73666572455344545472616e736665724d616e6167656456656320696e646578206f7574206f662072616e6765626164206172726179206c656e677468616d6f756e747465737464756d6d795f73635f61646472657373455344545472616e7366657240353535333434343332443333333533303633333436354030463432343000419082080b0438ffffff@0504"), + GasLimit: 100_000_000, + GasPrice: minGasPrice, + ChainID: []byte(configs.ChainID), + Version: 1, + Signature: []byte("dummy"), + } + + txResult, err := cs.SendTxAndGenerateBlockTilTxIsExecuted(tx, maxNumOfBlockToGenerateWhenExecutingTx) + require.Nil(t, err) + require.NotNil(t, txResult) + require.Equal(t, "success", txResult.Status.String()) + require.Equal(t, core.SignalErrorOperation, txResult.Logs.Events[0].Identifier) + require.Equal(t, "transfer value on esdt call", string(txResult.Logs.Events[0].Topics[1])) +} diff --git a/integrationTests/chainSimulator/vm/esdtImprovements_test.go b/integrationTests/chainSimulator/vm/esdtImprovements_test.go index 8361feb4fee..80748a22d7c 100644 --- a/integrationTests/chainSimulator/vm/esdtImprovements_test.go +++ b/integrationTests/chainSimulator/vm/esdtImprovements_test.go @@ -12,6 +12,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/esdt" "github.com/multiversx/mx-chain-core-go/data/transaction" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" testsChainSimulator "github.com/multiversx/mx-chain-go/integrationTests/chainSimulator" "github.com/multiversx/mx-chain-go/integrationTests/vm/txsFee" @@ -23,8 +26,6 @@ import ( "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/vm" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/require" ) const ( diff --git a/integrationTests/consensus/consensusSigning_test.go b/integrationTests/consensus/consensusSigning_test.go index 5091d2fafcc..4d6b57d3929 100644 --- a/integrationTests/consensus/consensusSigning_test.go +++ b/integrationTests/consensus/consensusSigning_test.go @@ -2,13 +2,16 @@ package consensus import ( "bytes" + "encoding/hex" "fmt" - "sync" "testing" "time" - "github.com/multiversx/mx-chain-go/integrationTests" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/integrationTests" ) func initNodesWithTestSigner( @@ -18,41 +21,54 @@ func initNodesWithTestSigner( numInvalid uint32, roundTime uint64, consensusType string, -) map[uint32][]*integrationTests.TestConsensusNode { +) (map[uint32][]*integrationTests.TestFullNode, map[string]struct{}) { fmt.Println("Step 1. 
Setup nodes...") - nodes := integrationTests.CreateNodesWithTestConsensusNode( + equivalentProofsActivationEpoch := uint32(0) + + enableEpochsConfig := integrationTests.CreateEnableEpochsConfig() + enableEpochsConfig.AndromedaEnableEpoch = equivalentProofsActivationEpoch + + nodes := integrationTests.CreateNodesWithTestFullNode( int(numMetaNodes), int(numNodes), int(consensusSize), roundTime, consensusType, 1, + enableEpochsConfig, + false, ) - for shardID, nodesList := range nodes { - displayAndStartNodes(shardID, nodesList) - } - time.Sleep(p2pBootstrapDelay) + invalidNodesAddresses := make(map[string]struct{}) + for shardID := range nodes { if numInvalid < numNodes { for i := uint32(0); i < numInvalid; i++ { ii := numNodes - i - 1 nodes[shardID][ii].MultiSigner.CreateSignatureShareCalled = func(privateKeyBytes, message []byte) ([]byte, error) { - // sig share with invalid size, it will be dropped but user won't be blacklisted - invalidSigShare := bytes.Repeat([]byte("a"), 3) - log.Warn("invalid sig share from ", "pk", getPkEncoded(nodes[shardID][ii].NodeKeys.Pk), "sig", invalidSigShare) + var invalidSigShare []byte + if i%2 == 0 { + // invalid sig share but with valid format + invalidSigShare, _ = hex.DecodeString("2ee350b9a821e20df97ba487a80b0d0ffffca7da663185cf6a562edc7c2c71e3ca46ed71b31bccaf53c626b87f2b6e08") + } else { + // sig share with invalid size + invalidSigShare = bytes.Repeat([]byte("a"), 3) + } + log.Warn("invalid sig share from ", "pk", nodes[shardID][ii].NodeKeys.MainKey.Pk, "sig", invalidSigShare) return invalidSigShare, nil } + + invalidNodesAddresses[string(nodes[shardID][ii].OwnAccount.Address)] = struct{}{} } } } - return nodes + return nodes, invalidNodesAddresses } func TestConsensusWithInvalidSigners(t *testing.T) { @@ -65,9 +81,8 @@ func TestConsensusWithInvalidSigners(t *testing.T) { consensusSize := uint32(4) numInvalid := uint32(1) roundTime := uint64(1000) - numCommBlock := uint64(8) - nodes := initNodesWithTestSigner(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, blsConsensusType) + nodes, invalidNodesAddresses := initNodesWithTestSigner(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, blsConsensusType) defer func() { for shardID := range nodes { @@ -82,27 +97,34 @@ func TestConsensusWithInvalidSigners(t *testing.T) { fmt.Println("Start consensus...") time.Sleep(time.Second) - for shardID := range nodes { - mutex := &sync.Mutex{} - nonceForRoundMap := make(map[uint64]uint64) - totalCalled := 0 - - err := startNodesWithCommitBlock(nodes[shardID], mutex, nonceForRoundMap, &totalCalled) - assert.Nil(t, err) - - chDone := make(chan bool) - go checkBlockProposedEveryRound(numCommBlock, nonceForRoundMap, mutex, chDone, t) - - extraTime := uint64(2) - endTime := time.Duration(roundTime)*time.Duration(numCommBlock+extraTime)*time.Millisecond + time.Minute - select { - case <-chDone: - case <-time.After(endTime): - mutex.Lock() - log.Error("currently saved nonces for rounds", "nonceForRoundMap", nonceForRoundMap) - assert.Fail(t, "consensus too slow, not working.") - mutex.Unlock() - return + for _, nodesList := range nodes { + for _, n := range nodesList { + err := startFullConsensusNode(n) + require.Nil(t, err) + } + } + + fmt.Println("Wait for several rounds...") + + time.Sleep(15 * time.Second) + + fmt.Println("Checking shards...") + + expectedNonce := uint64(10) + for _, nodesList := range nodes { + for _, n := range nodesList { + for i := 1; i < len(nodes); i++ { + _, ok := invalidNodesAddresses[string(n.OwnAccount.Address)] + if ok 
{ + continue + } + + if check.IfNil(n.Node.GetDataComponents().Blockchain().GetCurrentBlockHeader()) { + assert.Fail(t, fmt.Sprintf("Node with idx %d does not have a current block", i)) + } else { + assert.GreaterOrEqual(t, n.Node.GetDataComponents().Blockchain().GetCurrentBlockHeader().GetNonce(), expectedNonce) + } + } } } } diff --git a/integrationTests/consensus/consensus_test.go b/integrationTests/consensus/consensus_test.go index a94c5717efe..e1cb29a0611 100644 --- a/integrationTests/consensus/consensus_test.go +++ b/integrationTests/consensus/consensus_test.go @@ -8,16 +8,19 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/pubkeyConverter" "github.com/multiversx/mx-chain-core-go/data" crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" consensusComp "github.com/multiversx/mx-chain-go/factory/consensus" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/process" consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/assert" ) const ( @@ -31,17 +34,184 @@ var ( log = logger.GetOrCreate("integrationtests/consensus") ) -func encodeAddress(address []byte) string { - return hex.EncodeToString(address) +func TestConsensusBLSFullTestSingleKeys(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + runFullConsensusTest(t, blsConsensusType, 1) } -func getPkEncoded(pubKey crypto.PublicKey) string { - pk, err := pubKey.ToByteArray() +func TestConsensusBLSFullTestMultiKeys(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + runFullConsensusTest(t, blsConsensusType, 5) +} + +func TestConsensusBLSNotEnoughValidators(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + runConsensusWithNotEnoughValidators(t, blsConsensusType) +} + +func TestConsensusBLSWithFullProcessing_BeforeEquivalentProofs(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + testConsensusBLSWithFullProcessing(t, integrationTests.UnreachableEpoch, 1) +} + +func TestConsensusBLSWithFullProcessing_WithEquivalentProofs(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + testConsensusBLSWithFullProcessing(t, uint32(0), 1) +} + +func TestConsensusBLSWithFullProcessing_WithEquivalentProofs_MultiKeys(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + testConsensusBLSWithFullProcessing(t, uint32(0), 3) +} + +func testConsensusBLSWithFullProcessing(t *testing.T, equivalentProofsActivationEpoch uint32, numKeysOnEachNode int) { + numMetaNodes := uint32(2) + numNodes := uint32(2) + consensusSize := uint32(2 * numKeysOnEachNode) + roundTime := uint64(1000) + + log.Info("runFullNodesTest", + "numNodes", numNodes, + "numKeysOnEachNode", numKeysOnEachNode, + "consensusSize", consensusSize, + ) + + enableEpochsConfig := integrationTests.CreateEnableEpochsConfig() + + 
enableEpochsConfig.AndromedaEnableEpoch = equivalentProofsActivationEpoch + + fmt.Println("Step 1. Setup nodes...") + + nodes := integrationTests.CreateNodesWithTestFullNode( + int(numMetaNodes), + int(numNodes), + int(consensusSize), + roundTime, + blsConsensusType, + numKeysOnEachNode, + enableEpochsConfig, + true, + ) + + for shardID, nodesList := range nodes { + for _, n := range nodesList { + skBuff, _ := n.NodeKeys.MainKey.Sk.ToByteArray() + pkBuff, _ := n.NodeKeys.MainKey.Pk.ToByteArray() + + encodedNodePkBuff := testPubkeyConverter.SilentEncode(pkBuff, log) + + fmt.Printf("Shard ID: %v, sk: %s, pk: %s\n", + shardID, + hex.EncodeToString(skBuff), + encodedNodePkBuff, + ) + } + } + + time.Sleep(p2pBootstrapDelay) + + defer func() { + for _, nodesList := range nodes { + for _, n := range nodesList { + n.Close() + } + } + }() + + for _, nodesList := range nodes { + for _, n := range nodesList { + err := startFullConsensusNode(n) + require.Nil(t, err) + } + } + + fmt.Println("Wait for several rounds...") + + time.Sleep(15 * time.Second) + + fmt.Println("Checking shards...") + + expectedNonce := uint64(10) + for _, nodesList := range nodes { + for _, n := range nodesList { + for i := 1; i < len(nodes); i++ { + if check.IfNil(n.Node.GetDataComponents().Blockchain().GetCurrentBlockHeader()) { + assert.Fail(t, fmt.Sprintf("Node with idx %d does not have a current block", i)) + } else { + assert.GreaterOrEqual(t, n.Node.GetDataComponents().Blockchain().GetCurrentBlockHeader().GetNonce(), expectedNonce) + } + } + } + } +} + +func startFullConsensusNode( + n *integrationTests.TestFullNode, +) error { + statusComponents := integrationTests.GetDefaultStatusComponents() + + consensusArgs := consensusComp.ConsensusComponentsFactoryArgs{ + Config: config.Config{ + Consensus: config.ConsensusConfig{ + Type: blsConsensusType, + }, + ValidatorPubkeyConverter: config.PubkeyConfig{ + Length: 96, + Type: "bls", + SignatureLength: 48, + }, + TrieSync: config.TrieSyncConfig{ + NumConcurrentTrieSyncers: 5, + MaxHardCapForMissingNodes: 5, + TrieSyncerVersion: 2, + CheckNodesOnDisk: false, + }, + GeneralSettings: config.GeneralSettingsConfig{ + SyncProcessTimeInMillis: 6000, + }, + }, + BootstrapRoundIndex: 0, + CoreComponents: n.Node.GetCoreComponents(), + NetworkComponents: n.Node.GetNetworkComponents(), + CryptoComponents: n.Node.GetCryptoComponents(), + DataComponents: n.Node.GetDataComponents(), + ProcessComponents: n.Node.GetProcessComponents(), + StateComponents: n.Node.GetStateComponents(), + StatusComponents: statusComponents, + StatusCoreComponents: n.Node.GetStatusCoreComponents(), + ScheduledProcessor: &consensusMocks.ScheduledProcessorStub{}, + IsInImportMode: n.Node.IsInImportMode(), + } + + consensusFactory, err := consensusComp.NewConsensusComponentsFactory(consensusArgs) if err != nil { - return err.Error() + return err } - return encodeAddress(pk) + managedConsensusComponents, err := consensusComp.NewManagedConsensusComponents(consensusFactory) + if err != nil { + return err + } + + return managedConsensusComponents.Create() } func initNodesAndTest( @@ -52,6 +222,7 @@ func initNodesAndTest( roundTime uint64, consensusType string, numKeysOnEachNode int, + enableEpochsConfig config.EnableEpochs, ) map[uint32][]*integrationTests.TestConsensusNode { fmt.Println("Step 1. 
Setup nodes...") @@ -63,6 +234,7 @@ func initNodesAndTest( roundTime, consensusType, numKeysOnEachNode, + enableEpochsConfig, ) for shardID, nodesList := range nodes { @@ -215,10 +387,14 @@ func checkBlockProposedEveryRound(numCommBlock uint64, nonceForRoundMap map[uint } } -func runFullConsensusTest(t *testing.T, consensusType string, numKeysOnEachNode int) { +func runFullConsensusTest( + t *testing.T, + consensusType string, + numKeysOnEachNode int, +) { numMetaNodes := uint32(4) numNodes := uint32(4) - consensusSize := uint32(4 * numKeysOnEachNode) + consensusSize := uint32(3 * numKeysOnEachNode) numInvalid := uint32(0) roundTime := uint64(1000) numCommBlock := uint64(8) @@ -229,7 +405,21 @@ func runFullConsensusTest(t *testing.T, consensusType string, numKeysOnEachNode "consensusSize", consensusSize, ) - nodes := initNodesAndTest(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, consensusType, numKeysOnEachNode) + enableEpochsConfig := integrationTests.CreateEnableEpochsConfig() + + equivalentProofsActivationEpoch := integrationTests.UnreachableEpoch + enableEpochsConfig.AndromedaEnableEpoch = equivalentProofsActivationEpoch + + nodes := initNodesAndTest( + numMetaNodes, + numNodes, + consensusSize, + numInvalid, + roundTime, + consensusType, + numKeysOnEachNode, + enableEpochsConfig, + ) defer func() { for shardID := range nodes { @@ -270,29 +460,15 @@ func runFullConsensusTest(t *testing.T, consensusType string, numKeysOnEachNode } } -func TestConsensusBLSFullTestSingleKeys(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - runFullConsensusTest(t, blsConsensusType, 1) -} - -func TestConsensusBLSFullTestMultiKeys(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - runFullConsensusTest(t, blsConsensusType, 5) -} - func runConsensusWithNotEnoughValidators(t *testing.T, consensusType string) { numMetaNodes := uint32(4) numNodes := uint32(4) consensusSize := uint32(4) numInvalid := uint32(2) roundTime := uint64(1000) - nodes := initNodesAndTest(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, consensusType, 1) + enableEpochsConfig := integrationTests.CreateEnableEpochsConfig() + enableEpochsConfig.AndromedaEnableEpoch = integrationTests.UnreachableEpoch + nodes := initNodesAndTest(numMetaNodes, numNodes, consensusSize, numInvalid, roundTime, consensusType, 1, enableEpochsConfig) defer func() { for shardID := range nodes { @@ -325,14 +501,6 @@ func runConsensusWithNotEnoughValidators(t *testing.T, consensusType string) { } } -func TestConsensusBLSNotEnoughValidators(t *testing.T) { - if testing.Short() { - t.Skip("this is not a short test") - } - - runConsensusWithNotEnoughValidators(t, blsConsensusType) -} - func displayAndStartNodes(shardID uint32, nodes []*integrationTests.TestConsensusNode) { for _, n := range nodes { skBuff, _ := n.NodeKeys.Sk.ToByteArray() @@ -347,3 +515,16 @@ func displayAndStartNodes(shardID uint32, nodes []*integrationTests.TestConsensu ) } } + +func encodeAddress(address []byte) string { + return hex.EncodeToString(address) +} + +func getPkEncoded(pubKey crypto.PublicKey) string { + pk, err := pubKey.ToByteArray() + if err != nil { + return err.Error() + } + + return encodeAddress(pk) +} diff --git a/integrationTests/countInterceptor.go b/integrationTests/countInterceptor.go index fba328de387..f7a7d5fc6ee 100644 --- a/integrationTests/countInterceptor.go +++ b/integrationTests/countInterceptor.go @@ -4,6 +4,7 @@ import ( "sync" 
"github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/p2p" ) @@ -21,12 +22,12 @@ func NewCountInterceptor() *CountInterceptor { } // ProcessReceivedMessage is called each time a new message is received -func (ci *CountInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) error { +func (ci *CountInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) ([]byte, error) { ci.mutMessagesCount.Lock() ci.messagesCount[message.Topic()]++ ci.mutMessagesCount.Unlock() - return nil + return nil, nil } // MessageCount returns the number of messages received on the provided topic diff --git a/integrationTests/factory/bootstrapComponents/bootstrapComponents_test.go b/integrationTests/factory/bootstrapComponents/bootstrapComponents_test.go index 03601ec46b1..704db5455b8 100644 --- a/integrationTests/factory/bootstrapComponents/bootstrapComponents_test.go +++ b/integrationTests/factory/bootstrapComponents/bootstrapComponents_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) // ------------ Test BootstrapComponents -------------------- diff --git a/integrationTests/factory/componentsHelper.go b/integrationTests/factory/componentsHelper.go index 6ad6c5910bf..3006dd3182c 100644 --- a/integrationTests/factory/componentsHelper.go +++ b/integrationTests/factory/componentsHelper.go @@ -7,6 +7,7 @@ import ( "runtime/pprof" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/p2p" @@ -40,6 +41,8 @@ func CreateDefaultConfig(tb testing.TB) *config.Configs { systemSCConfig, _ := common.LoadSystemSmartContractsConfig(configPathsHolder.SystemSC) epochConfig, _ := common.LoadEpochConfig(configPathsHolder.Epoch) roundConfig, _ := common.LoadRoundConfig(configPathsHolder.RoundActivation) + var nodesConfig config.NodesConfig + _ = core.LoadJsonFile(&nodesConfig, NodesSetupPath) mainP2PConfig.KadDhtPeerDiscovery.Enabled = false prefsConfig.Preferences.DestinationShardAsObserver = "0" @@ -69,10 +72,32 @@ func CreateDefaultConfig(tb testing.TB) *config.Configs { } configs.ConfigurationPathsHolder = configPathsHolder configs.ImportDbConfig = &config.ImportDbConfig{} + configs.NodesConfig = &nodesConfig + + configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch = computeChainParameters(uint32(len(configs.NodesConfig.InitialNodes)), configs.GeneralConfig.GeneralSettings.GenesisMaxNumberOfShards) return configs } +func computeChainParameters(numInitialNodes uint32, numShardsWithoutMeta uint32) []config.ChainParametersByEpochConfig { + numShardsWithMeta := numShardsWithoutMeta + 1 + nodesPerShards := numInitialNodes / numShardsWithMeta + shardCnsGroupSize := nodesPerShards + if shardCnsGroupSize > 1 { + shardCnsGroupSize-- + } + diff := numInitialNodes - nodesPerShards*numShardsWithMeta + return []config.ChainParametersByEpochConfig{ + { + ShardConsensusGroupSize: shardCnsGroupSize, + ShardMinNumNodes: 
nodesPerShards, + MetachainConsensusGroupSize: nodesPerShards, + MetachainMinNumNodes: nodesPerShards + diff, + RoundDuration: 2000, + }, + } +} + func createConfigurationsPathsHolder() *config.ConfigurationPathsHolder { var concatPath = func(filename string) string { return path.Join(BaseNodeConfigPath, filename) diff --git a/integrationTests/factory/consensusComponents/consensusComponents_test.go b/integrationTests/factory/consensusComponents/consensusComponents_test.go index b68e9dd95cc..a7eec6bde69 100644 --- a/integrationTests/factory/consensusComponents/consensusComponents_test.go +++ b/integrationTests/factory/consensusComponents/consensusComponents_test.go @@ -6,13 +6,14 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/dataRetriever" bootstrapComp "github.com/multiversx/mx-chain-go/factory/bootstrap" "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) // ------------ Test TestConsensusComponents -------------------- @@ -78,6 +79,7 @@ func TestConsensusComponents_Close_ShouldWork(t *testing.T) { managedCoreComponents.EnableEpochsHandler(), managedDataComponents.Datapool().CurrentEpochValidatorInfo(), managedBootstrapComponents.NodesCoordinatorRegistryFactory(), + managedCoreComponents.ChainParametersHandler(), ) require.Nil(t, err) managedStatusComponents, err := nr.CreateManagedStatusComponents( diff --git a/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go b/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go index dd0a07ad91f..d296be05b04 100644 --- a/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go +++ b/integrationTests/factory/heartbeatComponents/heartbeatComponents_test.go @@ -6,13 +6,14 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/dataRetriever" bootstrapComp "github.com/multiversx/mx-chain-go/factory/bootstrap" "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) // ------------ Test TestHeartbeatComponents -------------------- @@ -78,6 +79,7 @@ func TestHeartbeatComponents_Close_ShouldWork(t *testing.T) { managedCoreComponents.EnableEpochsHandler(), managedDataComponents.Datapool().CurrentEpochValidatorInfo(), managedBootstrapComponents.NodesCoordinatorRegistryFactory(), + managedCoreComponents.ChainParametersHandler(), ) require.Nil(t, err) managedStatusComponents, err := nr.CreateManagedStatusComponents( diff --git a/integrationTests/factory/processComponents/processComponents_test.go b/integrationTests/factory/processComponents/processComponents_test.go index 17860520ea9..6f82bbf1188 100644 --- a/integrationTests/factory/processComponents/processComponents_test.go +++ 
b/integrationTests/factory/processComponents/processComponents_test.go @@ -6,13 +6,14 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/dataRetriever" bootstrapComp "github.com/multiversx/mx-chain-go/factory/bootstrap" "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) // ------------ Test TestProcessComponents -------------------- @@ -79,6 +80,7 @@ func TestProcessComponents_Close_ShouldWork(t *testing.T) { managedCoreComponents.EnableEpochsHandler(), managedDataComponents.Datapool().CurrentEpochValidatorInfo(), managedBootstrapComponents.NodesCoordinatorRegistryFactory(), + managedCoreComponents.ChainParametersHandler(), ) require.Nil(t, err) managedStatusComponents, err := nr.CreateManagedStatusComponents( diff --git a/integrationTests/factory/stateComponents/stateComponents_test.go b/integrationTests/factory/stateComponents/stateComponents_test.go index 18984a82bde..820694aa55e 100644 --- a/integrationTests/factory/stateComponents/stateComponents_test.go +++ b/integrationTests/factory/stateComponents/stateComponents_test.go @@ -6,10 +6,11 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) func TestStateComponents_Create_Close_ShouldWork(t *testing.T) { diff --git a/integrationTests/factory/statusComponents/statusComponents_test.go b/integrationTests/factory/statusComponents/statusComponents_test.go index dc5d3575b8c..488d20baea7 100644 --- a/integrationTests/factory/statusComponents/statusComponents_test.go +++ b/integrationTests/factory/statusComponents/statusComponents_test.go @@ -6,13 +6,14 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/dataRetriever" bootstrapComp "github.com/multiversx/mx-chain-go/factory/bootstrap" "github.com/multiversx/mx-chain-go/integrationTests/factory" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/testscommon/goroutines" - "github.com/stretchr/testify/require" ) // ------------ Test StatusComponents -------------------- @@ -79,6 +80,7 @@ func TestStatusComponents_Create_Close_ShouldWork(t *testing.T) { managedCoreComponents.EnableEpochsHandler(), managedDataComponents.Datapool().CurrentEpochValidatorInfo(), managedBootstrapComponents.NodesCoordinatorRegistryFactory(), + managedCoreComponents.ChainParametersHandler(), ) require.Nil(t, err) managedStatusComponents, err := nr.CreateManagedStatusComponents( diff --git a/integrationTests/factory/testdata/nodesSetup.json 
b/integrationTests/factory/testdata/nodesSetup.json index 239fd9a52f6..2a966c72ce8 100644 --- a/integrationTests/factory/testdata/nodesSetup.json +++ b/integrationTests/factory/testdata/nodesSetup.json @@ -1,12 +1,5 @@ { "startTime": 0, - "roundDuration": 4000, - "consensusGroupSize": 3, - "minNodesPerShard": 3, - "metaChainConsensusGroupSize": 3, - "metaChainMinNodes": 3, - "hysteresis": 0, - "adaptivity": false, "initialNodes": [ { "pubkey": "cbc8c9a6a8d9c874e89eb9366139368ae728bd3eda43f173756537877ba6bca87e01a97b815c9f691df73faa16f66b15603056540aa7252d73fecf05d24cd36b44332a88386788fbdb59d04502e8ecb0132d8ebd3d875be4c83e8b87c55eb901", diff --git a/integrationTests/frontend/staking/staking_test.go b/integrationTests/frontend/staking/staking_test.go index 8cba29bd032..fa29ea091cd 100644 --- a/integrationTests/frontend/staking/staking_test.go +++ b/integrationTests/frontend/staking/staking_test.go @@ -8,12 +8,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" - "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/vm" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/vm" ) var log = logger.GetOrCreate("integrationtests/frontend/staking") @@ -64,11 +65,11 @@ func TestSignatureOnStaking(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -109,7 +110,7 @@ func TestSignatureOnStaking(t *testing.T) { nrRoundsToPropagateMultiShard := 10 integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) diff --git a/integrationTests/interface.go b/integrationTests/interface.go index ad90ffbb6a3..2b78eec1f0f 100644 --- a/integrationTests/interface.go +++ b/integrationTests/interface.go @@ -69,6 +69,7 @@ type Facade interface { GetAllESDTTokens(address string, options api.AccountQueryOptions) (map[string]*esdt.ESDigitalToken, api.BlockInfo, error) GetESDTsRoles(address string, options api.AccountQueryOptions) (map[string][]string, api.BlockInfo, error) GetKeyValuePairs(address string, options api.AccountQueryOptions) (map[string]string, api.BlockInfo, error) + IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions) (map[string]string, [][]byte, api.BlockInfo, error) GetGuardianData(address string, options api.AccountQueryOptions) (api.GuardianData, api.BlockInfo, error) GetBlockByHash(hash string, options api.BlockQueryOptions) (*dataApi.Block, error) GetBlockByNonce(nonce uint64, options api.BlockQueryOptions) 
(*dataApi.Block, error) diff --git a/integrationTests/longTests/antiflooding/antiflooding_test.go b/integrationTests/longTests/antiflooding/antiflooding_test.go index ab3ec860489..ad2fca583d2 100644 --- a/integrationTests/longTests/antiflooding/antiflooding_test.go +++ b/integrationTests/longTests/antiflooding/antiflooding_test.go @@ -21,7 +21,7 @@ import ( var log = logger.GetOrCreate("integrationtests/longtests/antiflood") //nolint -//nolint +// nolint func createWorkableConfig() config.Config { return config.Config{ Antiflood: config.AntifloodConfig{ @@ -80,7 +80,7 @@ func createWorkableConfig() config.Config { } } -//nolint +// nolint func createDisabledConfig() config.Config { return config.Config{ Antiflood: config.AntifloodConfig{ @@ -114,7 +114,7 @@ func TestAntifloodingForLargerPeriodOfTime(t *testing.T) { } } -//nolint +// nolint func createProcessors(peers []p2p.Messenger, topic string, idxBadPeers []int, idxGoodPeers []int) []*messageProcessor { processors := make([]*messageProcessor, 0, len(peers)) ctx := context.Background() @@ -155,7 +155,7 @@ func createProcessors(peers []p2p.Messenger, topic string, idxBadPeers []int, id return processors } -//nolint +// nolint func intInSlice(searchFor int, slice []int) bool { for _, val := range slice { if searchFor == val { @@ -166,7 +166,7 @@ func intInSlice(searchFor int, slice []int) bool { return false } -//nolint +// nolint func displayProcessors(processors []*messageProcessor, idxBadPeers []int, idxRound int) { header := []string{"idx", "pid", "received", "processed", "received/s", "connections"} data := make([]*display.LineData, 0, len(processors)) @@ -199,7 +199,7 @@ func displayProcessors(processors []*messageProcessor, idxBadPeers []int, idxRou time.Sleep(timeBetweenPrints) } -//nolint +// nolint func startFlooding(peers []p2p.Messenger, topic string, idxBadPeers []int, maxSize int, msgSize int) { lastUpdated := time.Now() m := make(map[core.PeerID]int) diff --git a/integrationTests/longTests/antiflooding/messageProcessor.go b/integrationTests/longTests/antiflooding/messageProcessor.go index 5c3838dea61..144cc700fa5 100644 --- a/integrationTests/longTests/antiflooding/messageProcessor.go +++ b/integrationTests/longTests/antiflooding/messageProcessor.go @@ -5,6 +5,7 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" ) @@ -31,7 +32,7 @@ func NewMessageProcessor(antiflooder process.P2PAntifloodHandler, messenger p2p. 
} // ProcessReceivedMessage is the callback function from the p2p side whenever a new message is received -func (mp *messageProcessor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) error { +func (mp *messageProcessor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) ([]byte, error) { atomic.AddUint32(&mp.numMessagesReceived, 1) atomic.AddUint64(&mp.sizeMessagesReceived, uint64(len(message.Data()))) atomic.AddUint32(&mp.numMessagesReceivedPerInterval, 1) @@ -39,13 +40,13 @@ func (mp *messageProcessor) ProcessReceivedMessage(message p2p.MessageP2P, fromC err := mp.antiflooder.CanProcessMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } atomic.AddUint32(&mp.numMessagesProcessed, 1) atomic.AddUint64(&mp.sizeMessagesProcessed, uint64(len(message.Data()))) - return nil + return nil, nil } // NumMessagesProcessed returns the number of processed messages diff --git a/integrationTests/longTests/storage/storagePutRemove_test.go b/integrationTests/longTests/storage/storagePutRemove_test.go index 64c211a5a80..a10d0085ffc 100644 --- a/integrationTests/longTests/storage/storagePutRemove_test.go +++ b/integrationTests/longTests/storage/storagePutRemove_test.go @@ -58,7 +58,7 @@ func TestPutRemove(t *testing.T) { } } -//nolint +// nolint func generateValues(numPuts int, valuesPayloadSize int) map[string][]byte { m := make(map[string][]byte) for i := 0; i < numPuts; i++ { @@ -74,7 +74,7 @@ func generateValues(numPuts int, valuesPayloadSize int) map[string][]byte { return m } -//nolint +// nolint func putValues(store storage.Storer, values map[string][]byte, rmv map[int][][]byte, idx int) { hashes := make([][]byte, 0, len(rmv)) for key, val := range values { @@ -86,7 +86,7 @@ func putValues(store storage.Storer, values map[string][]byte, rmv map[int][][]b rmv[idx] = hashes } -//nolint +// nolint func removeOld(store storage.Storer, rmv map[int][][]byte, idx int) { hashes, found := rmv[idx-2] if !found { diff --git a/integrationTests/miniNetwork.go b/integrationTests/miniNetwork.go index e9c64f5606d..9424a566c07 100644 --- a/integrationTests/miniNetwork.go +++ b/integrationTests/miniNetwork.go @@ -71,10 +71,10 @@ func (n *MiniNetwork) Start() { // Continue advances processing with a number of rounds func (n *MiniNetwork) Continue(t *testing.T, numRounds int) { - idxProposers := []int{0, 1} + leaders := []*TestProcessorNode{n.Nodes[0], n.Nodes[1]} for i := int64(0); i < int64(numRounds); i++ { - n.Nonce, n.Round = ProposeAndSyncOneBlock(t, n.Nodes, idxProposers, n.Round, n.Nonce) + n.Nonce, n.Round = ProposeAndSyncOneBlock(t, n.Nodes, leaders, n.Round, n.Nonce) } } diff --git a/integrationTests/mock/blockProcessorMock.go b/integrationTests/mock/blockProcessorMock.go index fb83fcfb0af..b3f42dd8e52 100644 --- a/integrationTests/mock/blockProcessorMock.go +++ b/integrationTests/mock/blockProcessorMock.go @@ -24,6 +24,7 @@ type BlockProcessorMock struct { CreateNewHeaderCalled func(round uint64, nonce uint64) (data.HeaderHandler, error) PruneStateOnRollbackCalled func(currHeader data.HeaderHandler, currHeaderHash []byte, prevHeader data.HeaderHandler, prevHeaderHash []byte) RevertStateToBlockCalled func(header data.HeaderHandler, rootHash []byte) error + DecodeBlockHeaderCalled func(dta []byte) data.HeaderHandler } // ProcessBlock mocks processing a block @@ -137,6 +138,10 @@ func (bpm *BlockProcessorMock) DecodeBlockBody(dta []byte) data.BodyHandler { // DecodeBlockHeader 
method decodes block header from a given byte array func (bpm *BlockProcessorMock) DecodeBlockHeader(dta []byte) data.HeaderHandler { + if bpm.DecodeBlockHeaderCalled != nil { + return bpm.DecodeBlockHeaderCalled(dta) + } + if dta == nil { return nil } diff --git a/integrationTests/mock/coreComponentsStub.go b/integrationTests/mock/coreComponentsStub.go index dca3f5a1fa6..f221a77610b 100644 --- a/integrationTests/mock/coreComponentsStub.go +++ b/integrationTests/mock/coreComponentsStub.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/factory" @@ -54,6 +55,10 @@ type CoreComponentsStub struct { ProcessStatusHandlerInternal common.ProcessStatusHandler HardforkTriggerPubKeyField []byte EnableEpochsHandlerField common.EnableEpochsHandler + ChainParametersHandlerField process.ChainParametersHandler + ChainParametersSubscriberField process.ChainParametersSubscriber + FieldsSizeCheckerField common.FieldsSizeChecker + EpochChangeGracePeriodHandlerField common.EpochChangeGracePeriodHandler } // Create - @@ -259,6 +264,26 @@ func (ccs *CoreComponentsStub) EnableEpochsHandler() common.EnableEpochsHandler return ccs.EnableEpochsHandlerField } +// ChainParametersHandler - +func (ccs *CoreComponentsStub) ChainParametersHandler() process.ChainParametersHandler { + return ccs.ChainParametersHandlerField +} + +// ChainParametersSubscriber - +func (ccs *CoreComponentsStub) ChainParametersSubscriber() process.ChainParametersSubscriber { + return ccs.ChainParametersSubscriberField +} + +// FieldsSizeChecker - +func (ccs *CoreComponentsStub) FieldsSizeChecker() common.FieldsSizeChecker { + return ccs.FieldsSizeCheckerField +} + +// EpochChangeGracePeriodHandler - +func (ccs *CoreComponentsStub) EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler { + return ccs.EpochChangeGracePeriodHandlerField +} + // IsInterfaceNil - func (ccs *CoreComponentsStub) IsInterfaceNil() bool { return ccs == nil diff --git a/integrationTests/mock/epochStartNotifier.go b/integrationTests/mock/epochStartNotifier.go index c4675a37401..8c6fb4c51e8 100644 --- a/integrationTests/mock/epochStartNotifier.go +++ b/integrationTests/mock/epochStartNotifier.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/integrationTests/mock/forkDetectorStub.go b/integrationTests/mock/forkDetectorStub.go index 950dd2b2e21..dba71b1dd38 100644 --- a/integrationTests/mock/forkDetectorStub.go +++ b/integrationTests/mock/forkDetectorStub.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -19,6 +20,8 @@ type ForkDetectorStub struct { SetRollBackNonceCalled func(nonce uint64) ResetProbableHighestNonceCalled func() SetFinalToLastCheckpointCalled func() + ReceivedProofCalled func(proof data.HeaderProofHandler) + AddCheckpointCalled func(nonce uint64, round uint64, hash []byte) } // RestoreToGenesis - @@ -114,6 +117,20 @@ func (fdm *ForkDetectorStub) SetFinalToLastCheckpoint() { } } +// ReceivedProof - +func (fdm *ForkDetectorStub) 
ReceivedProof(proof data.HeaderProofHandler) { + if fdm.ReceivedProofCalled != nil { + fdm.ReceivedProofCalled(proof) + } +} + +// AddCheckpoint - +func (fdm *ForkDetectorStub) AddCheckpoint(nonce uint64, round uint64, hash []byte) { + if fdm.AddCheckpointCalled != nil { + fdm.AddCheckpointCalled(nonce, round, hash) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (fdm *ForkDetectorStub) IsInterfaceNil() bool { return fdm == nil diff --git a/integrationTests/mock/headerSigVerifierStub.go b/integrationTests/mock/headerSigVerifierStub.go deleted file mode 100644 index b75b5615a12..00000000000 --- a/integrationTests/mock/headerSigVerifierStub.go +++ /dev/null @@ -1,52 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderSigVerifierStub - -type HeaderSigVerifierStub struct { - VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error - VerifySignatureCalled func(header data.HeaderHandler) error - VerifyRandSeedCalled func(header data.HeaderHandler) error - VerifyLeaderSignatureCalled func(header data.HeaderHandler) error -} - -// VerifyRandSeed - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeed(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedCalled != nil { - return hsvm.VerifyRandSeedCalled(header) - } - - return nil -} - -// VerifyRandSeedAndLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedAndLeaderSignatureCalled != nil { - return hsvm.VerifyRandSeedAndLeaderSignatureCalled(header) - } - - return nil -} - -// VerifySignature - -func (hsvm *HeaderSigVerifierStub) VerifySignature(header data.HeaderHandler) error { - if hsvm.VerifySignatureCalled != nil { - return hsvm.VerifySignatureCalled(header) - } - - return nil -} - -// VerifyLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyLeaderSignatureCalled != nil { - return hsvm.VerifyLeaderSignatureCalled(header) - } - - return nil -} - -// IsInterfaceNil - -func (hsvm *HeaderSigVerifierStub) IsInterfaceNil() bool { - return hsvm == nil -} diff --git a/integrationTests/mock/nilAntifloodHandler.go b/integrationTests/mock/nilAntifloodHandler.go index 868a2167767..fab73b964cd 100644 --- a/integrationTests/mock/nilAntifloodHandler.go +++ b/integrationTests/mock/nilAntifloodHandler.go @@ -13,6 +13,10 @@ import ( type NilAntifloodHandler struct { } +// SetConsensusSizeNotifier - +func (nah *NilAntifloodHandler) SetConsensusSizeNotifier(_ process.ChainParametersSubscriber, _ uint32) { +} + // ResetForTopic won't do anything func (nah *NilAntifloodHandler) ResetForTopic(_ string) { } diff --git a/integrationTests/mock/oneSCExecutorMockVM.go b/integrationTests/mock/oneSCExecutorMockVM.go index c7587eb976f..b6280635254 100644 --- a/integrationTests/mock/oneSCExecutorMockVM.go +++ b/integrationTests/mock/oneSCExecutorMockVM.go @@ -22,21 +22,22 @@ const getFunc = "get" var variableA = []byte("a") // OneSCExecutorMockVM contains one hardcoded SC with the following behaviour (written in golang): -//------------------------------------- +// ------------------------------------- // var a int // -// func init(initial int){ -// a = initial -// } +// func init(initial int){ +// a = initial +// } // -// func Add(value int){ -// a += value -// } +// func Add(value int){ +// a += value +// } // -// func Get() int{ -// return a -// } -//------------------------------------- +// 
func Get() int{ +// return a +// } +// +// ------------------------------------- type OneSCExecutorMockVM struct { blockchainHook vmcommon.BlockchainHook hasher hashing.Hasher diff --git a/integrationTests/mock/p2pAntifloodHandlerStub.go b/integrationTests/mock/p2pAntifloodHandlerStub.go index c181d10909d..3a9f89397b5 100644 --- a/integrationTests/mock/p2pAntifloodHandlerStub.go +++ b/integrationTests/mock/p2pAntifloodHandlerStub.go @@ -17,6 +17,7 @@ type P2PAntifloodHandlerStub struct { BlacklistPeerCalled func(peer core.PeerID, reason string, duration time.Duration) IsOriginatorEligibleForTopicCalled func(pid core.PeerID, topic string) error SetPeerValidatorMapperCalled func(validatorMapper process.PeerValidatorMapper) error + SetConsensusSizeNotifierCalled func(chainParametersNotifier process.ChainParametersSubscriber, shardID uint32) } // CanProcessMessage - @@ -50,6 +51,13 @@ func (stub *P2PAntifloodHandlerStub) ApplyConsensusSize(size int) { } } +// SetConsensusSizeNotifier - +func (p2pahs *P2PAntifloodHandlerStub) SetConsensusSizeNotifier(chainParametersNotifier process.ChainParametersSubscriber, shardID uint32) { + if p2pahs.SetConsensusSizeNotifierCalled != nil { + p2pahs.SetConsensusSizeNotifierCalled(chainParametersNotifier, shardID) + } +} + // SetDebugger - func (stub *P2PAntifloodHandlerStub) SetDebugger(debugger process.AntifloodDebugger) error { if stub.SetDebuggerCalled != nil { diff --git a/integrationTests/mock/roundHandlerMock.go b/integrationTests/mock/roundHandlerMock.go index 65a7ef5cc10..4234684ada7 100644 --- a/integrationTests/mock/roundHandlerMock.go +++ b/integrationTests/mock/roundHandlerMock.go @@ -19,6 +19,12 @@ func (mock *RoundHandlerMock) BeforeGenesis() bool { return false } +// RevertOneRound - +func (rndm *RoundHandlerMock) RevertOneRound() { + rndm.IndexField-- + rndm.TimeStampField = rndm.TimeStampField.Add(-rndm.TimeDurationField) +} + // Index - func (mock *RoundHandlerMock) Index() int64 { return mock.IndexField diff --git a/integrationTests/multiShard/block/common.go b/integrationTests/multiShard/block/common.go index e4fbd7403cc..481a7cf202a 100644 --- a/integrationTests/multiShard/block/common.go +++ b/integrationTests/multiShard/block/common.go @@ -2,28 +2,7 @@ package block import ( "time" - - "github.com/multiversx/mx-chain-go/integrationTests" ) // StepDelay - var StepDelay = time.Second / 10 - -// GetBlockProposersIndexes - -func GetBlockProposersIndexes( - consensusMap map[uint32][]*integrationTests.TestProcessorNode, - nodesMap map[uint32][]*integrationTests.TestProcessorNode, -) map[uint32]int { - - indexProposer := make(map[uint32]int) - - for sh, testNodeList := range nodesMap { - for k, testNode := range testNodeList { - if consensusMap[sh][0] == testNode { - indexProposer[sh] = k - } - } - } - - return indexProposer -} diff --git a/integrationTests/multiShard/block/edgecases/edgecases_test.go b/integrationTests/multiShard/block/edgecases/edgecases_test.go index 534cea84d31..6f041ee8609 100644 --- a/integrationTests/multiShard/block/edgecases/edgecases_test.go +++ b/integrationTests/multiShard/block/edgecases/edgecases_test.go @@ -9,12 +9,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-crypto-go" - "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/integrationTests/multiShard/block" - 
"github.com/multiversx/mx-chain-go/state" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/integrationTests/multiShard/block" + "github.com/multiversx/mx-chain-go/state" ) var log = logger.GetOrCreate("integrationTests/multishard/block") @@ -23,14 +24,14 @@ var log = logger.GetOrCreate("integrationTests/multishard/block") // A validator from shard 0 receives rewards from shard 1 (where it is assigned) and creates move balance // transactions. All other shard peers can and will sync the blocks containing the move balance transactions. func TestExecutingTransactionsFromRewardsFundsCrossShard(t *testing.T) { - //TODO fix this test + // TODO fix this test t.Skip("TODO fix this test") if testing.Short() { t.Skip("this is not a short test") } - //it is important to have all combinations here as to test more edgecases + // it is important to have all combinations here as to test more edgecases mapAssignements := map[uint32][]uint32{ 0: {1, 0}, 1: {0, 1}, @@ -73,17 +74,14 @@ func TestExecutingTransactionsFromRewardsFundsCrossShard(t *testing.T) { firstNode := nodesMap[senderShardID][0] numBlocksProduced := uint64(13) - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode for i := uint64(0); i < numBlocksProduced; i++ { printAccount(firstNode) for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - - indexesProposers := block.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + proposalData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposalData, nodesMap, round) time.Sleep(block.StepDelay) round++ @@ -132,7 +130,7 @@ func TestMetaShouldBeAbleToProduceBlockInAVeryHighRoundAndStartOfEpoch(t *testin } } - //edge case on the epoch change + // edge case on the epoch change round := roundsPerEpoch*10 - 1 nonce := uint64(1) round = integrationTests.IncrementAndPrintRound(round) @@ -141,9 +139,8 @@ func TestMetaShouldBeAbleToProduceBlockInAVeryHighRoundAndStartOfEpoch(t *testin integrationTests.UpdateRound(nodes, round) } - _, _, consensusNodes := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - indexesProposers := block.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, nonce) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, nonce) for _, nodes := range nodesMap { for _, node := range nodes { @@ -163,14 +160,14 @@ func closeNodes(nodesMap map[uint32][]*integrationTests.TestProcessorNode) { } } -//nolint +// nolint func checkSameBlockHeight(t *testing.T, nodesMap map[uint32][]*integrationTests.TestProcessorNode) { for _, nodes := range nodesMap { referenceBlock := nodes[0].BlockChain.GetCurrentBlockHeader() for _, n := range nodes { crtBlock := n.BlockChain.GetCurrentBlockHeader() - //(crtBlock == nil) != (blkc == nil) actually does a XOR operation between the 2 conditions - //as if the reference is nil, the same must be all other nodes. 
Same if the reference is not nil. + // (referenceBlock == nil) != (crtBlock == nil) is effectively a XOR between the two conditions: + // if the reference block is nil, all other nodes' blocks must be nil as well, and the same holds when it is not nil. require.False(t, (referenceBlock == nil) != (crtBlock == nil)) if !check.IfNil(referenceBlock) { require.Equal(t, referenceBlock.GetNonce(), crtBlock.GetNonce()) @@ -179,7 +176,7 @@ func checkSameBlockHeight(t *testing.T, nodesMap map[uint32][]*integrationTests. } } -//nolint +// nolint func printAccount(node *integrationTests.TestProcessorNode) { accnt, _ := node.AccntState.GetExistingAccount(node.OwnAccount.Address) if check.IfNil(accnt) { diff --git a/integrationTests/multiShard/block/executingMiniblocks/executingMiniblocks_test.go b/integrationTests/multiShard/block/executingMiniblocks/executingMiniblocks_test.go index eec61878296..fcf5ec9178c 100644 --- a/integrationTests/multiShard/block/executingMiniblocks/executingMiniblocks_test.go +++ b/integrationTests/multiShard/block/executingMiniblocks/executingMiniblocks_test.go @@ -33,7 +33,6 @@ func TestShouldProcessBlocksInMultiShardArchitecture(t *testing.T) { nodesPerShard := 3 numMetachainNodes := 1 - idxProposers := []int{0, 3, 6, 9, 12, 15, 18} senderShard := uint32(0) recvShards := []uint32{1, 2} round := uint64(0) @@ -47,6 +46,7 @@ func TestShouldProcessBlocksInMultiShardArchitecture(t *testing.T) { nodesPerShard, numMetachainNodes, ) + leaders := []*integrationTests.TestProcessorNode{nodes[0], nodes[3], nodes[6], nodes[9], nodes[12], nodes[15], nodes[18]} integrationTests.DisplayAndStartNodes(nodes) defer func() { @@ -97,7 +97,7 @@ func TestShouldProcessBlocksInMultiShardArchitecture(t *testing.T) { nonce++ roundsToWait := 6 for i := 0; i < roundsToWait; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) } gasPricePerTxBigInt := big.NewInt(0).SetUint64(integrationTests.MinTxGasPrice) @@ -163,11 +163,11 @@ func TestSimpleTransactionsWithMoreGasWhichYieldInReceiptsInMultiShardedEnvironm node.EconomicsData.SetMinGasLimit(minGasLimit, 0) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -192,8 +192,8 @@ func TestSimpleTransactionsWithMoreGasWhichYieldInReceiptsInMultiShardedEnvironm nrRoundsToTest := 10 for i := 0; i <= nrRoundsToTest; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -253,11 +253,11 @@ func TestSimpleTransactionsWithMoreValueThanBalanceYieldReceiptsInMultiShardedEn node.EconomicsData.SetMinGasLimit(minGasLimit, 0) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + 
leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -294,8 +294,8 @@ func TestSimpleTransactionsWithMoreValueThanBalanceYieldReceiptsInMultiShardedEn time.Sleep(2 * time.Second) integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -320,8 +320,8 @@ func TestSimpleTransactionsWithMoreValueThanBalanceYieldReceiptsInMultiShardedEn numRoundsToTest := 6 for i := 0; i < numRoundsToTest; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -420,22 +420,22 @@ func TestShouldSubtractTheCorrectTxFee(t *testing.T) { gasPrice, ) - _, _, consensusNodes := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) shardId0 := uint32(0) _ = integrationTests.IncrementAndPrintRound(round) // test sender account decreased its balance with gasPrice * gasLimit - accnt, err := consensusNodes[shardId0][0].AccntState.GetExistingAccount(ownerPk) + accnt, err := proposeData[shardId0].Leader.AccntState.GetExistingAccount(ownerPk) assert.Nil(t, err) ownerAccnt := accnt.(state.UserAccountHandler) expectedBalance := big.NewInt(0).Set(initialVal) tx := &transaction.Transaction{GasPrice: gasPrice, GasLimit: gasLimit, Data: []byte(txData)} - txCost := consensusNodes[shardId0][0].EconomicsData.ComputeTxFee(tx) + txCost := proposeData[shardId0].Leader.EconomicsData.ComputeTxFee(tx) expectedBalance.Sub(expectedBalance, txCost) assert.Equal(t, expectedBalance, ownerAccnt.GetBalance()) - printContainingTxs(consensusNodes[shardId0][0], consensusNodes[shardId0][0].BlockChain.GetCurrentBlockHeader().(*block.Header)) + printContainingTxs(proposeData[shardId0].Leader, proposeData[shardId0].Leader.BlockChain.GetCurrentBlockHeader().(*block.Header)) } func printContainingTxs(tpn *integrationTests.TestProcessorNode, hdr data.HeaderHandler) { diff --git a/integrationTests/multiShard/block/executingRewardMiniblocks/executingRewardMiniblocks_test.go b/integrationTests/multiShard/block/executingRewardMiniblocks/executingRewardMiniblocks_test.go index 38822aa6427..740ebb1bef2 100644 --- a/integrationTests/multiShard/block/executingRewardMiniblocks/executingRewardMiniblocks_test.go +++ b/integrationTests/multiShard/block/executingRewardMiniblocks/executingRewardMiniblocks_test.go @@ -10,11 +10,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" - testBlock "github.com/multiversx/mx-chain-go/integrationTests/multiShard/block" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" - "github.com/stretchr/testify/assert" ) func 
getLeaderPercentage(node *integrationTests.TestProcessorNode, epoch uint32) float64 { @@ -66,8 +66,6 @@ func TestExecuteBlocksWithTransactionsAndCheckRewards(t *testing.T) { nonce := uint64(1) nbBlocksProduced := 7 - var headers map[uint32]data.HeaderHandler - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode mapRewardsForShardAddresses := make(map[string]uint32) mapRewardsForMetachainAddresses := make(map[string]uint32) nbTxsForLeaderAddress := make(map[string]uint32) @@ -76,21 +74,18 @@ func TestExecuteBlocksWithTransactionsAndCheckRewards(t *testing.T) { for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - _, headers, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - for shardId := range consensusNodes { + for shardId := range proposeData { addrRewards := make([]string, 0) updateExpectedRewards(mapRewardsForShardAddresses, addrRewards) - nbTxs := getTransactionsFromHeaderInShard(t, headers, shardId) + nbTxs := getTransactionsFromHeaderInShard(t, proposeData[shardId].Header, shardId) if len(addrRewards) > 0 { updateNumberTransactionsProposed(t, nbTxsForLeaderAddress, addrRewards[0], nbTxs) } } - updateRewardsForMetachain(mapRewardsForMetachainAddresses, consensusNodes[0][0]) - - indexesProposers := testBlock.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) time.Sleep(integrationTests.StepDelay) @@ -149,18 +144,16 @@ func TestExecuteBlocksWithTransactionsWhichReachedGasLimitAndCheckRewards(t *tes nonce := uint64(1) nbBlocksProduced := 2 - var headers map[uint32]data.HeaderHandler - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode mapRewardsForShardAddresses := make(map[string]uint32) nbTxsForLeaderAddress := make(map[string]uint32) for i := 0; i < nbBlocksProduced; i++ { - _, headers, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - for shardId := range consensusNodes { + for shardId := range nodesMap { addrRewards := make([]string, 0) updateExpectedRewards(mapRewardsForShardAddresses, addrRewards) - nbTxs := getTransactionsFromHeaderInShard(t, headers, shardId) + nbTxs := getTransactionsFromHeaderInShard(t, proposeData[shardId].Header, shardId) if len(addrRewards) > 0 { updateNumberTransactionsProposed(t, nbTxsForLeaderAddress, addrRewards[0], nbTxs) } @@ -169,8 +162,7 @@ func TestExecuteBlocksWithTransactionsWhichReachedGasLimitAndCheckRewards(t *tes for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - indexesProposers := testBlock.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ } @@ -213,15 +205,14 @@ func TestExecuteBlocksWithoutTransactionsAndCheckRewards(t *testing.T) { nonce := uint64(1) nbBlocksProduced := 7 - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode mapRewardsForShardAddresses := make(map[string]uint32) mapRewardsForMetachainAddresses := make(map[string]uint32) nbTxsForLeaderAddress := make(map[string]uint32) for i := 0; i < nbBlocksProduced; i++ { - _, _, consensusNodes = 
integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - for shardId := range consensusNodes { + for shardId := range nodesMap { if shardId == core.MetachainShardId { continue } @@ -231,13 +222,10 @@ func TestExecuteBlocksWithoutTransactionsAndCheckRewards(t *testing.T) { updateExpectedRewards(mapRewardsForShardAddresses, addrRewards) } - updateRewardsForMetachain(mapRewardsForMetachainAddresses, consensusNodes[0][0]) - for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - indexesProposers := testBlock.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ } @@ -248,16 +236,11 @@ func TestExecuteBlocksWithoutTransactionsAndCheckRewards(t *testing.T) { verifyRewardsForMetachain(t, mapRewardsForMetachainAddresses, nodesMap) } -func getTransactionsFromHeaderInShard(t *testing.T, headers map[uint32]data.HeaderHandler, shardId uint32) uint32 { +func getTransactionsFromHeaderInShard(t *testing.T, header data.HeaderHandler, shardId uint32) uint32 { if shardId == core.MetachainShardId { return 0 } - header, ok := headers[shardId] - if !ok { - return 0 - } - hdr, ok := header.(*block.Header) if !ok { assert.Error(t, process.ErrWrongTypeAssertion) @@ -296,9 +279,6 @@ func updateNumberTransactionsProposed( transactionsForLeader[addressProposer] += nbTransactions } -func updateRewardsForMetachain(_ map[string]uint32, _ *integrationTests.TestProcessorNode) { -} - func verifyRewardsForMetachain( t *testing.T, mapRewardsForMeta map[string]uint32, diff --git a/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go b/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go index 82eca349947..099864c1dc8 100644 --- a/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go +++ b/integrationTests/multiShard/block/interceptedHeadersSigVerification/interceptedHeadersSigVerification_test.go @@ -11,8 +11,9 @@ import ( "github.com/multiversx/mx-chain-crypto-go" "github.com/multiversx/mx-chain-crypto-go/signing" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" - "github.com/multiversx/mx-chain-go/integrationTests" "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/integrationTests" ) const broadcastDelay = 2 * time.Second @@ -57,12 +58,12 @@ func TestInterceptedShardBlockHeaderVerifiedWithCorrectConsensusGroup(t *testing nonce := uint64(1) var err error - body, header, _, _ := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, randomness, 0) - header, err = fillHeaderFields(nodesMap[0][0], header, singleSigner) + proposeBlockData := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, randomness, 0) + header, err := fillHeaderFields(proposeBlockData.Leader, proposeBlockData.Header, singleSigner) assert.Nil(t, err) pk := nodesMap[0][0].NodeKeys.MainKey.Pk - nodesMap[0][0].BroadcastBlock(body, header, pk) + nodesMap[0][0].BroadcastBlock(proposeBlockData.Body, header, pk) time.Sleep(broadcastDelay) @@ -122,7 +123,7 @@ func TestInterceptedMetaBlockVerifiedWithCorrectConsensusGroup(t 
*testing.T) { round := uint64(1) nonce := uint64(1) - body, header, _, _ := integrationTests.ProposeBlockWithConsensusSignature( + proposeBlockData := integrationTests.ProposeBlockWithConsensusSignature( core.MetachainShardId, nodesMap, round, @@ -132,13 +133,13 @@ func TestInterceptedMetaBlockVerifiedWithCorrectConsensusGroup(t *testing.T) { ) pk := nodesMap[core.MetachainShardId][0].NodeKeys.MainKey.Pk - nodesMap[core.MetachainShardId][0].BroadcastBlock(body, header, pk) + nodesMap[core.MetachainShardId][0].BroadcastBlock(proposeBlockData.Body, proposeBlockData.Header, pk) time.Sleep(broadcastDelay) - headerBytes, _ := integrationTests.TestMarshalizer.Marshal(header) + headerBytes, _ := integrationTests.TestMarshalizer.Marshal(proposeBlockData.Header) headerHash := integrationTests.TestHasher.Compute(string(headerBytes)) - hmb := header.(*block.MetaBlock) + hmb := proposeBlockData.Header.(*block.MetaBlock) // all nodes in metachain do not have the block in pool as interceptor does not validate it with a wrong consensus for _, metaNode := range nodesMap[core.MetachainShardId] { @@ -197,16 +198,16 @@ func TestInterceptedShardBlockHeaderWithLeaderSignatureAndRandSeedChecks(t *test round := uint64(1) nonce := uint64(1) - body, header, _, consensusNodes := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, randomness, 0) - nodeToSendFrom := consensusNodes[0] - err := header.SetPrevRandSeed(randomness) + proposeBlockData := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, randomness, 0) + nodeToSendFrom := proposeBlockData.Leader + err := proposeBlockData.Header.SetPrevRandSeed(randomness) assert.Nil(t, err) - header, err = fillHeaderFields(nodeToSendFrom, header, singleSigner) + header, err := fillHeaderFields(nodeToSendFrom, proposeBlockData.Header, singleSigner) assert.Nil(t, err) pk := nodeToSendFrom.NodeKeys.MainKey.Pk - nodeToSendFrom.BroadcastBlock(body, header, pk) + nodeToSendFrom.BroadcastBlock(proposeBlockData.Body, header, pk) time.Sleep(broadcastDelay) @@ -268,14 +269,14 @@ func TestInterceptedShardHeaderBlockWithWrongPreviousRandSeedShouldNotBeAccepted wrongRandomness := []byte("wrong randomness") round := uint64(2) nonce := uint64(2) - body, header, _, _ := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, wrongRandomness, 0) + proposeBlockData := integrationTests.ProposeBlockWithConsensusSignature(0, nodesMap, round, nonce, wrongRandomness, 0) pk := nodesMap[0][0].NodeKeys.MainKey.Pk - nodesMap[0][0].BroadcastBlock(body, header, pk) + nodesMap[0][0].BroadcastBlock(proposeBlockData.Body, proposeBlockData.Header, pk) time.Sleep(broadcastDelay) - headerBytes, _ := integrationTests.TestMarshalizer.Marshal(header) + headerBytes, _ := integrationTests.TestMarshalizer.Marshal(proposeBlockData.Header) headerHash := integrationTests.TestHasher.Compute(string(headerBytes)) // all nodes in metachain have the block header in pool as interceptor validates it @@ -294,8 +295,11 @@ func TestInterceptedShardHeaderBlockWithWrongPreviousRandSeedShouldNotBeAccepted func fillHeaderFields(proposer *integrationTests.TestProcessorNode, hdr data.HeaderHandler, signer crypto.SingleSigner) (data.HeaderHandler, error) { leaderSk := proposer.NodeKeys.MainKey.Sk - randSeed, _ := signer.Sign(leaderSk, hdr.GetPrevRandSeed()) - err := hdr.SetRandSeed(randSeed) + randSeed, err := signer.Sign(leaderSk, hdr.GetPrevRandSeed()) + if err != nil { + return nil, err + } + err = hdr.SetRandSeed(randSeed) if err != nil { return nil, 
err } diff --git a/integrationTests/multiShard/endOfEpoch/common.go b/integrationTests/multiShard/endOfEpoch/common.go index c416479849d..4d3a6673703 100644 --- a/integrationTests/multiShard/endOfEpoch/common.go +++ b/integrationTests/multiShard/endOfEpoch/common.go @@ -6,9 +6,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/stretchr/testify/assert" ) // CreateAndPropagateBlocks - @@ -18,12 +19,12 @@ func CreateAndPropagateBlocks( currentRound uint64, currentNonce uint64, nodes []*integrationTests.TestProcessorNode, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, ) (uint64, uint64) { for i := uint64(0); i <= nbRounds; i++ { integrationTests.UpdateRound(nodes, currentRound) - integrationTests.ProposeBlock(nodes, idxProposers, currentRound, currentNonce) - integrationTests.SyncBlock(t, nodes, idxProposers, currentRound) + integrationTests.ProposeBlock(nodes, leaders, currentRound, currentNonce) + integrationTests.SyncBlock(t, nodes, leaders, currentRound) currentRound = integrationTests.IncrementAndPrintRound(currentRound) currentNonce++ } diff --git a/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShuffling/epochChangeWithNodesShuffling_test.go b/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShuffling/epochChangeWithNodesShuffling_test.go index a2b5846a759..a3d08fbd755 100644 --- a/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShuffling/epochChangeWithNodesShuffling_test.go +++ b/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShuffling/epochChangeWithNodesShuffling_test.go @@ -58,16 +58,14 @@ func TestEpochChangeWithNodesShuffling(t *testing.T) { nonce := uint64(1) nbBlocksToProduce := uint64(20) expectedLastEpoch := uint32(nbBlocksToProduce / roundsPerEpoch) - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode for i := uint64(0); i < nbBlocksToProduce; i++ { for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - indexesProposers := endOfEpoch.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ diff --git a/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShufflingAndRater/epochChangeWithNodesShufflingAndRater_test.go b/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShufflingAndRater/epochChangeWithNodesShufflingAndRater_test.go index 9c81ff6e97e..59c0abc5156 100644 --- a/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShufflingAndRater/epochChangeWithNodesShufflingAndRater_test.go +++ b/integrationTests/multiShard/endOfEpoch/epochChangeWithNodesShufflingAndRater/epochChangeWithNodesShufflingAndRater_test.go @@ -5,10 +5,11 @@ import ( "testing" "time" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/multiShard/endOfEpoch" 
"github.com/multiversx/mx-chain-go/process/rating" - logger "github.com/multiversx/mx-chain-logger-go" ) func TestEpochChangeWithNodesShufflingAndRater(t *testing.T) { @@ -68,16 +69,14 @@ func TestEpochChangeWithNodesShufflingAndRater(t *testing.T) { nonce := uint64(1) nbBlocksToProduce := uint64(20) expectedLastEpoch := uint32(nbBlocksToProduce / roundsPerEpoch) - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode for i := uint64(0); i < nbBlocksToProduce; i++ { for _, nodes := range nodesMap { integrationTests.UpdateRound(nodes, round) } - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - indexesProposers := endOfEpoch.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ diff --git a/integrationTests/multiShard/endOfEpoch/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment_test.go b/integrationTests/multiShard/endOfEpoch/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment_test.go index dd964aeb745..f8bde5fb75d 100644 --- a/integrationTests/multiShard/endOfEpoch/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment_test.go +++ b/integrationTests/multiShard/endOfEpoch/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment/epochStartChangeWithContinuousTransactionsInMultiShardedEnvironment_test.go @@ -26,6 +26,7 @@ func TestEpochStartChangeWithContinuousTransactionsInMultiShardedEnvironment(t * StakingV4Step1EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step2EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step3EnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -40,11 +41,11 @@ func TestEpochStartChangeWithContinuousTransactionsInMultiShardedEnvironment(t * node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -71,8 +72,8 @@ func TestEpochStartChangeWithContinuousTransactionsInMultiShardedEnvironment(t * nrRoundsToPropagateMultiShard := uint64(5) for i := uint64(0); i <= (uint64(epoch)*roundsPerEpoch)+nrRoundsToPropagateMultiShard; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ diff --git a/integrationTests/multiShard/endOfEpoch/epochStartChangeWithoutTransactionInMultiShardedEnvironment/epochStartChangeWithoutTransactionInMultiShardedEnvironment_test.go 
b/integrationTests/multiShard/endOfEpoch/epochStartChangeWithoutTransactionInMultiShardedEnvironment/epochStartChangeWithoutTransactionInMultiShardedEnvironment_test.go index d14eb086de6..92e626602a9 100644 --- a/integrationTests/multiShard/endOfEpoch/epochStartChangeWithoutTransactionInMultiShardedEnvironment/epochStartChangeWithoutTransactionInMultiShardedEnvironment_test.go +++ b/integrationTests/multiShard/endOfEpoch/epochStartChangeWithoutTransactionInMultiShardedEnvironment/epochStartChangeWithoutTransactionInMultiShardedEnvironment_test.go @@ -25,6 +25,7 @@ func TestEpochStartChangeWithoutTransactionInMultiShardedEnvironment(t *testing. StakingV4Step1EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step2EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step3EnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -39,11 +40,11 @@ func TestEpochStartChangeWithoutTransactionInMultiShardedEnvironment(t *testing. node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -61,10 +62,10 @@ func TestEpochStartChangeWithoutTransactionInMultiShardedEnvironment(t *testing. time.Sleep(time.Second) // ----- wait for epoch end period - round, nonce = endOfEpoch.CreateAndPropagateBlocks(t, roundsPerEpoch, round, nonce, nodes, idxProposers) + round, nonce = endOfEpoch.CreateAndPropagateBlocks(t, roundsPerEpoch, round, nonce, nodes, leaders) nrRoundsToPropagateMultiShard := uint64(5) - _, _ = endOfEpoch.CreateAndPropagateBlocks(t, nrRoundsToPropagateMultiShard, round, nonce, nodes, idxProposers) + _, _ = endOfEpoch.CreateAndPropagateBlocks(t, nrRoundsToPropagateMultiShard, round, nonce, nodes, leaders) epoch := uint32(1) endOfEpoch.VerifyThatNodesHaveCorrectEpoch(t, epoch, nodes) diff --git a/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go b/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go index ce933a22666..8ac99dad211 100644 --- a/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go +++ b/integrationTests/multiShard/endOfEpoch/startInEpoch/startInEpoch_test.go @@ -11,7 +11,12 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/endProcess" "github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -24,6 +29,7 @@ import ( "github.com/multiversx/mx-chain-go/process" 
"github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" "github.com/multiversx/mx-chain-go/process/block/pendingMb" + interceptorsFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" "github.com/multiversx/mx-chain-go/process/smartContract" "github.com/multiversx/mx-chain-go/process/sync/storageBootstrap" "github.com/multiversx/mx-chain-go/sharding" @@ -31,7 +37,8 @@ import ( "github.com/multiversx/mx-chain-go/storage/factory" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" - epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" @@ -40,7 +47,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/scheduledDataSyncer" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) func TestStartInEpochForAShardNodeInMultiShardedEnvironment(t *testing.T) { @@ -72,6 +78,7 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui StakingV4Step1EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step2EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step3EnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -86,11 +93,11 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * numNodesPerShard + leaders[i] = nodes[i*numNodesPerShard] } - idxProposers[numOfShards] = numOfShards * numNodesPerShard + leaders[numOfShards] = nodes[numOfShards*numNodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -117,8 +124,8 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui nrRoundsToPropagateMultiShard := uint64(5) for i := uint64(0); i <= (uint64(epoch)*roundsPerEpoch)+nrRoundsToPropagateMultiShard; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -221,7 +228,9 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui cryptoComponents.BlKeyGen = &mock.KeyGenMock{} cryptoComponents.TxKeyGen = &mock.KeyGenMock{} - coreComponents := integrationTests.GetDefaultCoreComponents(integrationTests.CreateEnableEpochsConfig()) + genericEpochNotifier := 
forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(enableEpochsConfig, genericEpochNotifier) + coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = integrationTests.TestMarshalizer coreComponents.TxMarshalizerField = integrationTests.TestMarshalizer coreComponents.HasherField = integrationTests.TestHasher @@ -234,11 +243,16 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui coreComponents.NodeTypeProviderField = &nodeTypeProviderMock.NodeTypeProviderStub{} coreComponents.ChanStopNodeProcessField = endProcess.GetDummyEndProcessChannel() coreComponents.HardforkTriggerPubKeyField = []byte("provided hardfork pub key") + coreComponents.ChainParametersHandlerField = &chainParameters.ChainParametersHandlerStub{} nodesCoordinatorRegistryFactory, _ := nodesCoordinator.NewNodesCoordinatorRegistryFactory( &marshallerMock.MarshalizerMock{}, 444, ) + interceptorDataVerifierArgs := interceptorsFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Second * 5, + CacheExpiry: time.Second * 10, + } argsBootstrapHandler := bootstrap.ArgsEpochStartBootstrap{ NodesCoordinatorRegistryFactory: nodesCoordinatorRegistryFactory, CryptoComponentsHolder: cryptoComponents, @@ -277,8 +291,10 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui FlagsConfig: config.ContextFlagsConfig{ ForceStartFromNetwork: false, }, - TrieSyncStatisticsProvider: &testscommon.SizeSyncStatisticsHandlerStub{}, - StateStatsHandler: disabled.NewStateStatistics(), + TrieSyncStatisticsProvider: &testscommon.SizeSyncStatisticsHandlerStub{}, + StateStatsHandler: disabled.NewStateStatistics(), + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + InterceptedDataVerifierFactory: interceptorsFactory.NewInterceptedDataVerifierFactory(interceptorDataVerifierArgs), } epochStartBootstrap, err := bootstrap.NewEpochStartBootstrap(argsBootstrapHandler) @@ -345,9 +361,11 @@ func testNodeStartsInEpoch(t *testing.T, shardID uint32, expectedHighestRound ui ChainID: string(integrationTests.ChainID), ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, MiniblocksProvider: &mock.MiniBlocksProviderStub{}, - EpochNotifier: &epochNotifierMock.EpochNotifierStub{}, + EpochNotifier: genericEpochNotifier, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, AppStatusHandler: &statusHandlerMock.AppStatusHandlerMock{}, + EnableEpochsHandler: enableEpochsHandler, + ProofsPool: &dataRetrieverMocks.ProofsPoolMock{}, } bootstrapper, err := getBootstrapper(shardID, argsBaseBootstrapper) diff --git a/integrationTests/multiShard/hardFork/hardFork_test.go b/integrationTests/multiShard/hardFork/hardFork_test.go index 81c9e4652f4..5b2754110ef 100644 --- a/integrationTests/multiShard/hardFork/hardFork_test.go +++ b/integrationTests/multiShard/hardFork/hardFork_test.go @@ -12,6 +12,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + logger "github.com/multiversx/mx-chain-logger-go" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common/enablers" + 
"github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -20,6 +27,7 @@ import ( "github.com/multiversx/mx-chain-go/integrationTests/mock" "github.com/multiversx/mx-chain-go/integrationTests/vm/wasm" vmFactory "github.com/multiversx/mx-chain-go/process/factory" + interceptorFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" commonMocks "github.com/multiversx/mx-chain-go/testscommon/common" @@ -31,10 +39,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/update/factory" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" - logger "github.com/multiversx/mx-chain-logger-go" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var log = logger.GetOrCreate("integrationTests/hardfork") @@ -64,11 +68,11 @@ func TestHardForkWithoutTransactionInMultiShardedEnvironment(t *testing.T) { node.WaitTime = 100 * time.Millisecond } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -89,11 +93,11 @@ func TestHardForkWithoutTransactionInMultiShardedEnvironment(t *testing.T) { nrRoundsToPropagateMultiShard := 5 // ----- wait for epoch end period - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, int(roundsPerEpoch), nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, int(roundsPerEpoch), nonce, round) time.Sleep(time.Second) - nonce, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) @@ -135,11 +139,11 @@ func TestHardForkWithContinuousTransactionsInMultiShardedEnvironment(t *testing. node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -189,7 +193,7 @@ func TestHardForkWithContinuousTransactionsInMultiShardedEnvironment(t *testing. 
epoch := uint32(2) nrRoundsToPropagateMultiShard := uint64(6) for i := uint64(0); i <= (uint64(epoch)*roundsPerEpoch)+nrRoundsToPropagateMultiShard; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) for _, node := range nodes { integrationTests.CreateAndSendTransaction(node, nodes, sendValue, receiverAddress1, "", integrationTests.AdditionalGasLimit) @@ -253,11 +257,11 @@ func TestHardForkEarlyEndOfEpochWithContinuousTransactionsInMultiShardedEnvironm node.EpochStartTrigger.SetMinRoundsBetweenEpochs(minRoundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = allNodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = allNodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(allNodes) @@ -310,7 +314,7 @@ func TestHardForkEarlyEndOfEpochWithContinuousTransactionsInMultiShardedEnvironm log.LogIfError(err) } - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, consensusNodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, consensusNodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(consensusNodes) for _, node := range consensusNodes { integrationTests.CreateAndSendTransaction(node, allNodes, sendValue, receiverAddress1, "", integrationTests.AdditionalGasLimit) @@ -388,7 +392,9 @@ func hardForkImport( defaults.FillGasMapInternal(gasSchedule, 1) log.Warn("started import process") - coreComponents := integrationTests.GetDefaultCoreComponents(integrationTests.CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(integrationTests.CreateEnableEpochsConfig(), genericEpochNotifier) + coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = integrationTests.TestMarshalizer coreComponents.TxMarshalizerField = integrationTests.TestMarshalizer coreComponents.HasherField = integrationTests.TestHasher @@ -569,7 +575,9 @@ func createHardForkExporter( returnedConfigs[node.ShardCoordinator.SelfId()] = append(returnedConfigs[node.ShardCoordinator.SelfId()], exportConfig) returnedConfigs[node.ShardCoordinator.SelfId()] = append(returnedConfigs[node.ShardCoordinator.SelfId()], keysConfig) - coreComponents := integrationTests.GetDefaultCoreComponents(integrationTests.CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(integrationTests.CreateEnableEpochsConfig(), genericEpochNotifier) + coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = integrationTests.TestMarshalizer coreComponents.TxMarshalizerField = integrationTests.TestTxSignMarshalizer coreComponents.HasherField = integrationTests.TestHasher @@ -601,6 +609,11 @@ func createHardForkExporter( networkComponents.PeersRatingHandlerField = node.PeersRatingHandler networkComponents.InputAntiFlood = &mock.NilAntifloodHandler{} networkComponents.OutputAntiFlood = &mock.NilAntifloodHandler{} + + 
interceptorDataVerifierFactoryArgs := interceptorFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Second * 5, + CacheExpiry: time.Second * 10, + } argsExportHandler := factory.ArgsExporter{ CoreComponents: coreComponents, CryptoComponents: cryptoComponents, @@ -650,11 +663,12 @@ func createHardForkExporter( NumResolveFailureThreshold: 3, DebugLineExpiration: 3, }, - MaxHardCapForMissingNodes: 500, - NumConcurrentTrieSyncers: 50, - TrieSyncerVersion: 2, - CheckNodesOnDisk: false, - NodeOperationMode: node.NodeOperationMode, + MaxHardCapForMissingNodes: 500, + NumConcurrentTrieSyncers: 50, + TrieSyncerVersion: 2, + CheckNodesOnDisk: false, + NodeOperationMode: node.NodeOperationMode, + InterceptedDataVerifierFactory: interceptorFactory.NewInterceptedDataVerifierFactory(interceptorDataVerifierFactoryArgs), } exportHandler, err := factory.NewExportHandlerFactory(argsExportHandler) diff --git a/integrationTests/multiShard/relayedTx/common.go b/integrationTests/multiShard/relayedTx/common.go index c440b574d8c..9ffa22aa079 100644 --- a/integrationTests/multiShard/relayedTx/common.go +++ b/integrationTests/multiShard/relayedTx/common.go @@ -17,21 +17,21 @@ import ( var log = logger.GetOrCreate("relayedtests") // CreateGeneralSetupForRelayTxTest will create the general setup for relayed transactions -func CreateGeneralSetupForRelayTxTest(baseCostFixEnabled bool) ([]*integrationTests.TestProcessorNode, []int, []*integrationTests.TestWalletAccount, *integrationTests.TestWalletAccount) { +func CreateGeneralSetupForRelayTxTest(baseCostFixEnabled bool) ([]*integrationTests.TestProcessorNode, []*integrationTests.TestProcessorNode, []*integrationTests.TestWalletAccount, *integrationTests.TestWalletAccount) { initialVal := big.NewInt(10000000000) epochsConfig := integrationTests.GetDefaultEnableEpochsConfig() if !baseCostFixEnabled { epochsConfig.FixRelayedBaseCostEnableEpoch = integrationTests.UnreachableEpoch epochsConfig.FixRelayedMoveBalanceToNonPayableSCEnableEpoch = integrationTests.UnreachableEpoch } - nodes, idxProposers := createAndMintNodes(initialVal, epochsConfig) + nodes, leaders := createAndMintNodes(initialVal, epochsConfig) players, relayerAccount := createAndMintPlayers(baseCostFixEnabled, nodes, initialVal) - return nodes, idxProposers, players, relayerAccount + return nodes, leaders, players, relayerAccount } -func createAndMintNodes(initialVal *big.Int, enableEpochsConfig *config.EnableEpochs) ([]*integrationTests.TestProcessorNode, []int) { +func createAndMintNodes(initialVal *big.Int, enableEpochsConfig *config.EnableEpochs) ([]*integrationTests.TestProcessorNode, []*integrationTests.TestProcessorNode) { numOfShards := 2 nodesPerShard := 2 numMetachainNodes := 1 @@ -43,17 +43,17 @@ func createAndMintNodes(initialVal *big.Int, enableEpochsConfig *config.EnableEp enableEpochsConfig, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) integrationTests.MintAllNodes(nodes, initialVal) - return nodes, idxProposers + return nodes, leaders } func createAndMintPlayers( diff --git a/integrationTests/multiShard/relayedTx/edgecases/edgecases_test.go b/integrationTests/multiShard/relayedTx/edgecases/edgecases_test.go index 72e7bafda2e..51660323333 
100644 --- a/integrationTests/multiShard/relayedTx/edgecases/edgecases_test.go +++ b/integrationTests/multiShard/relayedTx/edgecases/edgecases_test.go @@ -16,7 +16,7 @@ func TestRelayedTransactionInMultiShardEnvironmentWithNormalTxButWrongNonceShoul t.Skip("this is not a short test") } - nodes, idxProposers, players, relayer := relayedTx.CreateGeneralSetupForRelayTxTest(false) + nodes, leaders, players, relayer := relayedTx.CreateGeneralSetupForRelayTxTest(false) defer func() { for _, n := range nodes { n.Close() @@ -40,7 +40,7 @@ func TestRelayedTransactionInMultiShardEnvironmentWithNormalTxButWrongNonceShoul _, _ = relayedTx.CreateAndSendRelayedAndUserTx(nodes, relayer, player, receiverAddress2, sendValue, integrationTests.MinTxGasLimit, []byte("")) } - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(time.Second) @@ -48,7 +48,7 @@ func TestRelayedTransactionInMultiShardEnvironmentWithNormalTxButWrongNonceShoul roundToPropagateMultiShard := int64(20) for i := int64(0); i <= roundToPropagateMultiShard; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) } @@ -74,7 +74,7 @@ func TestRelayedTransactionInMultiShardEnvironmentWithNormalTxButWithTooMuchGas( t.Skip("this is not a short test") } - nodes, idxProposers, players, relayer := relayedTx.CreateGeneralSetupForRelayTxTest(false) + nodes, leaders, players, relayer := relayedTx.CreateGeneralSetupForRelayTxTest(false) defer func() { for _, n := range nodes { n.Close() @@ -105,7 +105,7 @@ func TestRelayedTransactionInMultiShardEnvironmentWithNormalTxButWithTooMuchGas( _, _ = relayedTx.CreateAndSendRelayedAndUserTx(nodes, relayer, player, receiverAddress2, sendValue, tooMuchGasLimit, []byte("")) } - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(time.Second) @@ -113,7 +113,7 @@ func TestRelayedTransactionInMultiShardEnvironmentWithNormalTxButWithTooMuchGas( roundToPropagateMultiShard := int64(20) for i := int64(0); i <= roundToPropagateMultiShard; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) } diff --git a/integrationTests/multiShard/relayedTx/relayedTx_test.go b/integrationTests/multiShard/relayedTx/relayedTx_test.go index c815a5b5eac..bba732f7a5b 100644 --- a/integrationTests/multiShard/relayedTx/relayedTx_test.go +++ b/integrationTests/multiShard/relayedTx/relayedTx_test.go @@ -9,6 +9,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/esdt" "github.com/multiversx/mx-chain-core-go/data/transaction" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + 
"github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/wasm" "github.com/multiversx/mx-chain-go/process" @@ -16,9 +20,6 @@ import ( "github.com/multiversx/mx-chain-go/process/smartContract/hooks" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/vm" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) type createAndSendRelayedAndUserTxFuncType = func( diff --git a/integrationTests/multiShard/smartContract/dns/dns_test.go b/integrationTests/multiShard/smartContract/dns/dns_test.go index 1f983617cb1..40f7117e469 100644 --- a/integrationTests/multiShard/smartContract/dns/dns_test.go +++ b/integrationTests/multiShard/smartContract/dns/dns_test.go @@ -12,14 +12,15 @@ import ( "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-core-go/hashing/keccak" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/genesis" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/multiShard/relayedTx" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestSCCallingDNSUserNames(t *testing.T) { @@ -27,7 +28,7 @@ func TestSCCallingDNSUserNames(t *testing.T) { t.Skip("this is not a short test") } - nodes, players, idxProposers := prepareNodesAndPlayers() + nodes, players, leaders := prepareNodesAndPlayers() defer func() { for _, n := range nodes { n.Close() @@ -45,7 +46,7 @@ func TestSCCallingDNSUserNames(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 25 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) checkUserNamesAreSetCorrectly(t, players, nodes, userNames, sortedDNSAddresses) } @@ -55,7 +56,7 @@ func TestSCCallingDNSUserNamesTwice(t *testing.T) { t.Skip("this is not a short test") } - nodes, players, idxProposers := prepareNodesAndPlayers() + nodes, players, leaders := prepareNodesAndPlayers() defer func() { for _, n := range nodes { n.Close() @@ -73,12 +74,12 @@ func TestSCCallingDNSUserNamesTwice(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 15 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) newUserNames := sendRegisterUserNameTxForPlayers(players, nodes, sortedDNSAddresses, dnsRegisterValue) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, 
leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) checkUserNamesAreSetCorrectly(t, players, nodes, userNames, sortedDNSAddresses) checkUserNamesAreDeleted(t, nodes, newUserNames, sortedDNSAddresses) @@ -89,7 +90,7 @@ func TestDNSandRelayedTxNormal(t *testing.T) { t.Skip("this is not a short test") } - nodes, players, idxProposers := prepareNodesAndPlayers() + nodes, players, leaders := prepareNodesAndPlayers() defer func() { for _, n := range nodes { n.Close() @@ -108,7 +109,7 @@ func TestDNSandRelayedTxNormal(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 30 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) checkUserNamesAreSetCorrectly(t, players, nodes, userNames, sortedDNSAddresses) } @@ -122,7 +123,7 @@ func createAndMintRelayer(nodes []*integrationTests.TestProcessorNode) *integrat return relayer } -func prepareNodesAndPlayers() ([]*integrationTests.TestProcessorNode, []*integrationTests.TestWalletAccount, []int) { +func prepareNodesAndPlayers() ([]*integrationTests.TestProcessorNode, []*integrationTests.TestWalletAccount, []*integrationTests.TestProcessorNode) { numOfShards := 2 nodesPerShard := 1 numMetachainNodes := 1 @@ -143,11 +144,11 @@ func prepareNodesAndPlayers() ([]*integrationTests.TestProcessorNode, []*integra node.EconomicsData.SetMaxGasLimitPerBlock(1500000000, 0) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -163,7 +164,7 @@ func prepareNodesAndPlayers() ([]*integrationTests.TestProcessorNode, []*integra integrationTests.MintAllNodes(nodes, initialVal) integrationTests.MintAllPlayers(nodes, players, initialVal) - return nodes, players, idxProposers + return nodes, players, leaders } func getDNSContractsData(node *integrationTests.TestProcessorNode) (*big.Int, []string) { diff --git a/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go b/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go index b74acc3b392..7ed4de2112c 100644 --- a/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go +++ b/integrationTests/multiShard/smartContract/polynetworkbridge/bridge_test.go @@ -7,6 +7,9 @@ import ( "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/process" @@ -14,8 +17,6 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestBridgeSetupAndBurn(t *testing.T) { @@ -31,6 +32,7 @@ func TestBridgeSetupAndBurn(t *testing.T) { GlobalMintBurnDisableEpoch: integrationTests.UnreachableEpoch, SCProcessorV2EnableEpoch: 
integrationTests.UnreachableEpoch, FixAsyncCallBackArgsListEnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, } arwenVersion := config.WasmVMVersionByEpoch{Version: "v1.4"} vmConfig := &config.VirtualMachineConfig{ @@ -48,11 +50,11 @@ func TestBridgeSetupAndBurn(t *testing.T) { ownerNode := nodes[0] shard := nodes[0:nodesPerShard] - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -73,7 +75,7 @@ func TestBridgeSetupAndBurn(t *testing.T) { nonce++ tokenManagerPath := "../testdata/polynetworkbridge/esdt_token_manager.wasm" - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) blockChainHook := ownerNode.BlockchainHook scAddressBytes, _ := blockChainHook.NewAddress( @@ -100,7 +102,7 @@ func TestBridgeSetupAndBurn(t *testing.T) { deploymentData, 100000, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) txValue := big.NewInt(1000) txData := "performWrappedEgldIssue@05" @@ -112,7 +114,7 @@ func TestBridgeSetupAndBurn(t *testing.T) { txData, integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 8, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 8, nonce, round) scQuery := &process.SCQuery{ CallerAddr: ownerNode.OwnAccount.Address, @@ -140,7 +142,7 @@ func TestBridgeSetupAndBurn(t *testing.T) { integrationTests.AdditionalGasLimit, ) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round) checkBurnedOnESDTContract(t, nodes, tokenIdentifier, valueToBurn) } diff --git a/integrationTests/multiShard/smartContract/scCallingSC_test.go b/integrationTests/multiShard/smartContract/scCallingSC_test.go index 52b24371d15..9f46f6e8f03 100644 --- a/integrationTests/multiShard/smartContract/scCallingSC_test.go +++ b/integrationTests/multiShard/smartContract/scCallingSC_test.go @@ -16,16 +16,17 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/transaction" vmData "github.com/multiversx/mx-chain-core-go/data/vm" + logger "github.com/multiversx/mx-chain-logger-go" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" systemVm "github.com/multiversx/mx-chain-go/vm" - logger "github.com/multiversx/mx-chain-logger-go" - 
vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var log = logger.GetOrCreate("integrationtests/multishard/smartcontract") @@ -45,11 +46,10 @@ func TestSCCallingIntraShard(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard integrationTests.DisplayAndStartNodes(nodes) @@ -86,7 +86,7 @@ func TestSCCallingIntraShard(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 + // 000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 secondSCAddress := putDeploySCToDataPool( "./testdata/second/output/second.wasm", secondSCOwner, @@ -96,10 +96,10 @@ func TestSCCallingIntraShard(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 + // 00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 // Run two rounds, so the two SmartContracts get deployed. - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) @@ -113,7 +113,7 @@ func TestSCCallingIntraShard(t *testing.T) { } time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) // verify how many times was the first SC called for index, node := range nodes { @@ -142,11 +142,11 @@ func TestScDeployAndChangeScOwner(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numShards+1) for i := 0; i < numShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numShards] = numShards * nodesPerShard + leaders[numShards] = nodes[numShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -176,8 +176,8 @@ func TestScDeployAndChangeScOwner(t *testing.T) { nonce := uint64(0) round = integrationTests.IncrementAndPrintRound(round) nonce++ - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -195,8 +195,8 @@ func TestScDeployAndChangeScOwner(t *testing.T) { for i := 0; i < numRoundsToPropagateMultiShard; i++ { integrationTests.UpdateRound(nodes, round) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ } @@ -222,8 +222,8 @@ func TestScDeployAndChangeScOwner(t *testing.T) { for i := 0; i < numRoundsToPropagateMultiShard; i++ { integrationTests.UpdateRound(nodes, round) 
integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ } @@ -252,11 +252,11 @@ func TestScDeployAndClaimSmartContractDeveloperRewards(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numShards+1) for i := 0; i < numShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numShards] = numShards * nodesPerShard + leaders[numShards] = nodes[numShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -289,8 +289,8 @@ func TestScDeployAndClaimSmartContractDeveloperRewards(t *testing.T) { nonce := uint64(0) round = integrationTests.IncrementAndPrintRound(round) nonce++ - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -308,8 +308,8 @@ func TestScDeployAndClaimSmartContractDeveloperRewards(t *testing.T) { for i := 0; i < 5; i++ { integrationTests.UpdateRound(nodes, round) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ } @@ -346,8 +346,8 @@ func TestScDeployAndClaimSmartContractDeveloperRewards(t *testing.T) { for i := 0; i < 3; i++ { integrationTests.UpdateRound(nodes, round) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) - integrationTests.SyncBlock(t, nodes, idxProposers, round) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) + integrationTests.SyncBlock(t, nodes, leaders, round) round = integrationTests.IncrementAndPrintRound(round) nonce++ } @@ -381,11 +381,11 @@ func TestSCCallingInCrossShard(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -422,7 +422,7 @@ func TestSCCallingInCrossShard(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 + // 000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 secondSCAddress := putDeploySCToDataPool( "./testdata/second/output/second.wasm", secondSCOwner, @@ -432,9 +432,9 @@ func TestSCCallingInCrossShard(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 + // 00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 - nonce, round = integrationTests.WaitOperationToBeDone(t, 
nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // make smart contract call to shard 1 which will do in shard 0 for _, node := range nodes { @@ -452,7 +452,7 @@ func TestSCCallingInCrossShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 10 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) // verify how many times was shard 0 and shard 1 called shId := nodes[0].ShardCoordinator.ComputeId(firstSCAddress) @@ -518,11 +518,11 @@ func TestSCCallingBuiltinAndFails(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -557,7 +557,7 @@ func TestSCCallingBuiltinAndFails(t *testing.T) { nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) sender := nodes[0] receiver := nodes[1] @@ -576,7 +576,7 @@ func TestSCCallingBuiltinAndFails(t *testing.T) { ) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) testValue1 := vm.GetIntValueFromSC(nil, sender.AccntState, scAddress, "testValue1", nil) require.NotNil(t, testValue1) require.Equal(t, uint64(255), testValue1.Uint64()) @@ -606,18 +606,16 @@ func TestSCCallingInCrossShardDelegationMock(t *testing.T) { ) nodes := make([]*integrationTests.TestProcessorNode, 0) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for _, nds := range nodesMap { nodes = append(nodes, nds...) 
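WaitOperationToBeDone now takes the leader nodes ahead of the full node list and keeps returning the advanced nonce and round. A sketch of a call under the updated signature, assuming the uint64 counters used by these tests; the wrapper name is illustrative.

package integrationtestsexamples

import (
	"testing"

	"github.com/multiversx/mx-chain-go/integrationTests"
)

// propagateOperation wraps the updated call: the leader nodes come first, then the full
// node list, then the number of rounds to run; the advanced nonce and round are
// returned so the caller can chain further operations.
func propagateOperation(
	t *testing.T,
	leaders []*integrationTests.TestProcessorNode,
	nodes []*integrationTests.TestProcessorNode,
	nrRounds int,
	nonce uint64,
	round uint64,
) (uint64, uint64) {
	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRounds, nonce, round)

	return nonce, round
}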
} - for _, nds := range nodesMap { - idx, err := getNodeIndex(nodes, nds[0]) - assert.Nil(t, err) - - idxProposers = append(idxProposers, idx) + for i := 0; i < numOfShards; i++ { + leaders[i] = nodesMap[uint32(i)][0] } + leaders[numOfShards] = nodesMap[core.MetachainShardId][0] integrationTests.DisplayAndStartNodes(nodes) @@ -652,7 +650,7 @@ func TestSCCallingInCrossShardDelegationMock(t *testing.T) { nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // one node calls to stake all the money from the delegation - that's how the contract is :D node := nodes[0] @@ -665,7 +663,7 @@ func TestSCCallingInCrossShardDelegationMock(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 10 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) // verify system smart contract has the value @@ -707,18 +705,16 @@ func TestSCCallingInCrossShardDelegation(t *testing.T) { ) nodes := make([]*integrationTests.TestProcessorNode, 0) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for _, nds := range nodesMap { nodes = append(nodes, nds...) } - for _, nds := range nodesMap { - idx, err := getNodeIndex(nodes, nds[0]) - assert.Nil(t, err) - - idxProposers = append(idxProposers, idx) + for i := 0; i < numOfShards; i++ { + leaders[i] = nodesMap[uint32(i)][0] } + leaders[numOfShards] = nodesMap[core.MetachainShardId][0] integrationTests.DisplayAndStartNodes(nodes) @@ -761,7 +757,7 @@ func TestSCCallingInCrossShardDelegation(t *testing.T) { ) shardNode.OwnAccount.Nonce++ - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // check that the version is the expected one scQueryVersion := &process.SCQuery{ @@ -775,13 +771,13 @@ func TestSCCallingInCrossShardDelegation(t *testing.T) { require.True(t, bytes.Contains(vmOutputVersion.ReturnData[0], []byte("0.3."))) log.Info("SC deployed", "version", string(vmOutputVersion.ReturnData[0])) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // set stake per node setStakePerNodeTxData := "setStakePerNode@" + core.ConvertToEvenHexBigInt(nodePrice) integrationTests.CreateAndSendTransaction(shardNode, nodes, big.NewInt(0), delegateSCAddress, setStakePerNodeTxData, integrationTests.AdditionalGasLimit) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // add node addNodesTxData := fmt.Sprintf("addNodes@%s@%s", @@ -789,25 +785,25 @@ func TestSCCallingInCrossShardDelegation(t *testing.T) { hex.EncodeToString(stakerBLSSignature)) integrationTests.CreateAndSendTransaction(shardNode, nodes, big.NewInt(0), delegateSCAddress, addNodesTxData, integrationTests.AdditionalGasLimit) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = 
integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // stake some coin! // here the node account fills all the required stake stakeTxData := "stake" integrationTests.CreateAndSendTransaction(shardNode, nodes, totalStake, delegateSCAddress, stakeTxData, integrationTests.AdditionalGasLimit) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) // activate the delegation, this involves an async call to validatorSC stakeAllAvailableTxData := "stakeAllAvailable" integrationTests.CreateAndSendTransaction(shardNode, nodes, big.NewInt(0), delegateSCAddress, stakeAllAvailableTxData, 2*integrationTests.AdditionalGasLimit) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round) time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 10 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) @@ -890,11 +886,10 @@ func TestSCNonPayableIntraShardErrorShouldProcessBlock(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard integrationTests.DisplayAndStartNodes(nodes) @@ -931,7 +926,7 @@ func TestSCNonPayableIntraShardErrorShouldProcessBlock(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 + // 000000000000000005005d3d53b5d0fcf07d222170978932166ee9f3972d3030 secondSCAddress := putDeploySCToDataPool( "./testdata/second/output/second.wasm", secondSCOwner, @@ -941,10 +936,10 @@ func TestSCNonPayableIntraShardErrorShouldProcessBlock(t *testing.T) { nodes, nodes[0].EconomicsData.MaxGasLimitPerBlock(0)-1, ) - //00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 + // 00000000000000000500017cc09151c48b99e2a1522fb70a5118ad4cb26c3031 // Run two rounds, so the two SmartContracts get deployed. 
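Where nodes are created per shard as a map, the getNodeIndex lookup becomes unnecessary: the first node of each map entry is taken directly, with core.MetachainShardId keying the metachain proposer. A sketch of that selection, assuming the map[uint32][]*TestProcessorNode layout iterated in the delegation tests above; the helper name is illustrative.

package integrationtestsexamples

import (
	"github.com/multiversx/mx-chain-core-go/core"

	"github.com/multiversx/mx-chain-go/integrationTests"
)

// leadersFromNodesMap flattens the per-shard node map and picks the first node of every
// shard as leader, with the metachain entry placed last, as in the delegation tests.
func leadersFromNodesMap(
	nodesMap map[uint32][]*integrationTests.TestProcessorNode,
	numOfShards int,
) ([]*integrationTests.TestProcessorNode, []*integrationTests.TestProcessorNode) {
	nodes := make([]*integrationTests.TestProcessorNode, 0)
	for _, nds := range nodesMap {
		nodes = append(nodes, nds...)
	}

	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
	for i := 0; i < numOfShards; i++ {
		leaders[i] = nodesMap[uint32(i)][0]
	}
	leaders[numOfShards] = nodesMap[core.MetachainShardId][0]

	return nodes, leaders
}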
- nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) @@ -958,23 +953,13 @@ func TestSCNonPayableIntraShardErrorShouldProcessBlock(t *testing.T) { } time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) for _, node := range nodes { assert.Equal(t, uint64(5), node.BlockChain.GetCurrentBlockHeader().GetNonce()) } } -func getNodeIndex(nodeList []*integrationTests.TestProcessorNode, node *integrationTests.TestProcessorNode) (int, error) { - for i := range nodeList { - if node == nodeList[i] { - return i, nil - } - } - - return 0, errors.New("no such node in list") -} - func putDeploySCToDataPool( fileName string, pubkey []byte, diff --git a/integrationTests/multiShard/softfork/scDeploy_test.go b/integrationTests/multiShard/softfork/scDeploy_test.go index 8af125f5797..5b4252b7806 100644 --- a/integrationTests/multiShard/softfork/scDeploy_test.go +++ b/integrationTests/multiShard/softfork/scDeploy_test.go @@ -11,12 +11,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" - "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/process/factory" - "github.com/multiversx/mx-chain-go/state" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/process/factory" + "github.com/multiversx/mx-chain-go/state" ) var log = logger.GetOrCreate("integrationtests/singleshard/block/softfork") @@ -67,7 +68,7 @@ func TestScDeploy(t *testing.T) { } integrationTests.ConnectNodes(connectableNodes) - idxProposers := []int{0, 1} + leaders := []*integrationTests.TestProcessorNode{nodes[0], nodes[1]} defer func() { for _, n := range nodes { @@ -93,7 +94,7 @@ func TestScDeploy(t *testing.T) { for i := uint64(0); i < numRounds; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -108,7 +109,7 @@ func TestScDeploy(t *testing.T) { deploySucceeded := deploySc(t, nodes) for i := uint64(0); i < 5; i++ { integrationTests.UpdateRound(nodes, round) - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) round = integrationTests.IncrementAndPrintRound(round) nonce++ diff --git a/integrationTests/multiShard/txScenarios/builtinFunctions_test.go b/integrationTests/multiShard/txScenarios/builtinFunctions_test.go index 1064239cbb0..0285cd0f5fd 100644 --- a/integrationTests/multiShard/txScenarios/builtinFunctions_test.go +++ b/integrationTests/multiShard/txScenarios/builtinFunctions_test.go @@ -8,9 +8,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/process/factory" - 
"github.com/stretchr/testify/assert" ) func TestTransaction_TransactionBuiltinFunctionsScenarios(t *testing.T) { @@ -19,7 +20,7 @@ func TestTransaction_TransactionBuiltinFunctionsScenarios(t *testing.T) { } initialBalance := big.NewInt(1000000000000) - nodes, idxProposers, players := createGeneralSetupForTxTest(initialBalance) + nodes, leaders, players := createGeneralSetupForTxTest(initialBalance) defer func() { for _, n := range nodes { n.Close() @@ -50,7 +51,7 @@ func TestTransaction_TransactionBuiltinFunctionsScenarios(t *testing.T) { nrRoundsToTest := int64(5) for i := int64(0); i < nrRoundsToTest; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(time.Second) @@ -74,7 +75,7 @@ func TestTransaction_TransactionBuiltinFunctionsScenarios(t *testing.T) { time.Sleep(time.Millisecond) for i := int64(0); i < nrRoundsToTest; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(time.Second) @@ -103,7 +104,7 @@ func TestTransaction_TransactionBuiltinFunctionsScenarios(t *testing.T) { time.Sleep(time.Millisecond) for i := int64(0); i < nrRoundsToTest; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(time.Second) } diff --git a/integrationTests/multiShard/txScenarios/common.go b/integrationTests/multiShard/txScenarios/common.go index d720b9d8df5..78c11d8f2df 100644 --- a/integrationTests/multiShard/txScenarios/common.go +++ b/integrationTests/multiShard/txScenarios/common.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/state" @@ -29,7 +30,7 @@ func createGeneralTestnetForTxTest( func createGeneralSetupForTxTest(initialBalance *big.Int) ( []*integrationTests.TestProcessorNode, - []int, + []*integrationTests.TestProcessorNode, []*integrationTests.TestWalletAccount, ) { numOfShards := 2 @@ -40,6 +41,7 @@ func createGeneralSetupForTxTest(initialBalance *big.Int) ( OptimizeGasUsedInCrossMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -49,11 +51,11 @@ func createGeneralSetupForTxTest(initialBalance *big.Int) ( enableEpochs, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -68,7 +70,7 @@ func 
createGeneralSetupForTxTest(initialBalance *big.Int) ( integrationTests.MintAllPlayers(nodes, players, initialBalance) - return nodes, idxProposers, players + return nodes, leaders, players } func createAndSendTransaction( diff --git a/integrationTests/multiShard/txScenarios/moveBalance_test.go b/integrationTests/multiShard/txScenarios/moveBalance_test.go index 5df383f7ebb..8599e5a45db 100644 --- a/integrationTests/multiShard/txScenarios/moveBalance_test.go +++ b/integrationTests/multiShard/txScenarios/moveBalance_test.go @@ -6,9 +6,10 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core/pubkeyConverter" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/vm" - "github.com/stretchr/testify/assert" ) func TestTransaction_TransactionMoveBalanceScenarios(t *testing.T) { @@ -17,7 +18,7 @@ func TestTransaction_TransactionMoveBalanceScenarios(t *testing.T) { } initialBalance := big.NewInt(1000000000000) - nodes, idxProposers, players := createGeneralSetupForTxTest(initialBalance) + nodes, leaders, players := createGeneralSetupForTxTest(initialBalance) defer func() { for _, n := range nodes { n.Close() @@ -65,7 +66,7 @@ func TestTransaction_TransactionMoveBalanceScenarios(t *testing.T) { nrRoundsToTest := int64(7) for i := int64(0); i < nrRoundsToTest; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(integrationTests.StepDelay) @@ -80,7 +81,7 @@ func TestTransaction_TransactionMoveBalanceScenarios(t *testing.T) { assert.Equal(t, players[2].Nonce, senderAccount.GetNonce()) assert.Equal(t, expectedBalance, senderAccount.GetBalance()) - //check balance intra shard tx insufficient gas limit + // check balance intra shard tx insufficient gas limit senderAccount = getUserAccount(nodes, players[1].Address) assert.Equal(t, uint64(0), senderAccount.GetNonce()) assert.Equal(t, initialBalance, senderAccount.GetBalance()) @@ -116,7 +117,7 @@ func TestTransaction_TransactionMoveBalanceScenarios(t *testing.T) { roundToPropagateMultiShard := int64(15) for i := int64(0); i <= roundToPropagateMultiShard; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) time.Sleep(integrationTests.StepDelay) } diff --git a/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go b/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go index b28c5dc054e..06e6d8892c7 100644 --- a/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go +++ b/integrationTests/multiShard/validatorToDelegation/validatorToDelegation_test.go @@ -8,13 +8,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/state" 
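The transaction-scenario tests drive the chain the same way after each batch of transactions: propose and sync one block with the chosen leaders, add the self-notarized metachain header, then pause for dissemination. A compact sketch of that loop, with the round count and one-second sleep as placeholders; only the call sequence mirrors the loops above.

package integrationtestsexamples

import (
	"testing"
	"time"

	"github.com/multiversx/mx-chain-go/integrationTests"
)

// propagateRounds repeats the propose/sync step used by the txScenarios tests: each
// iteration proposes and syncs one block with the given leaders, adds the
// self-notarized metachain header and waits briefly for dissemination.
func propagateRounds(
	t *testing.T,
	nodes []*integrationTests.TestProcessorNode,
	leaders []*integrationTests.TestProcessorNode,
	round uint64,
	nonce uint64,
	nrRounds int,
) (uint64, uint64) {
	for i := 0; i < nrRounds; i++ {
		round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce)
		integrationTests.AddSelfNotarizedHeaderByMetachain(nodes)

		time.Sleep(time.Second)
	}

	return round, nonce
}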
"github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestValidatorToDelegationManagerWithNewContract(t *testing.T) { @@ -34,11 +35,11 @@ func TestValidatorToDelegationManagerWithNewContract(t *testing.T) { stakingWalletAccount := integrationTests.CreateTestWalletAccount(nodes[0].ShardCoordinator, nodes[0].ShardCoordinator.SelfId()) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -71,7 +72,7 @@ func TestValidatorToDelegationManagerWithNewContract(t *testing.T) { t, nodes, stakingWalletAccount, - idxProposers, + leaders, nodePrice, frontendBLSPubkey, frontendHexSignature, @@ -87,7 +88,7 @@ func TestValidatorToDelegationManagerWithNewContract(t *testing.T) { t, nodes, stakingWalletAccount, - idxProposers, + leaders, "makeNewContractFromValidatorData", big.NewInt(0), []byte{10}, @@ -124,11 +125,11 @@ func testValidatorToDelegationWithMerge(t *testing.T, withJail bool) { stakingWalletAccount := integrationTests.CreateTestWalletAccount(nodes[0].ShardCoordinator, nodes[0].ShardCoordinator.SelfId()) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -162,7 +163,7 @@ func testValidatorToDelegationWithMerge(t *testing.T, withJail bool) { t, nodes, stakingWalletAccount, - idxProposers, + leaders, nodePrice, frontendBLSPubkey, frontendHexSignature, @@ -182,7 +183,7 @@ func testValidatorToDelegationWithMerge(t *testing.T, withJail bool) { t, nodes, stakingWalletAccount, - idxProposers, + leaders, "createNewDelegationContract", big.NewInt(10000), []byte{0}, @@ -206,7 +207,7 @@ func testValidatorToDelegationWithMerge(t *testing.T, withJail bool) { time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) @@ -258,11 +259,11 @@ func TestValidatorToDelegationManagerWithWhiteListAndMerge(t *testing.T) { stakingWalletAccount1 := integrationTests.CreateTestWalletAccount(nodes[0].ShardCoordinator, nodes[0].ShardCoordinator.SelfId()) stakingWalletAccount2 := integrationTests.CreateTestWalletAccount(nodes[0].ShardCoordinator, nodes[0].ShardCoordinator.SelfId()) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -296,7 +297,7 @@ func 
TestValidatorToDelegationManagerWithWhiteListAndMerge(t *testing.T) { t, nodes, stakingWalletAccount1, - idxProposers, + leaders, nodePrice, frontendBLSPubkey, frontendHexSignature, @@ -312,7 +313,7 @@ func TestValidatorToDelegationManagerWithWhiteListAndMerge(t *testing.T) { t, nodes, stakingWalletAccount2, - idxProposers, + leaders, "createNewDelegationContract", big.NewInt(10000), []byte{0}, @@ -335,7 +336,7 @@ func TestValidatorToDelegationManagerWithWhiteListAndMerge(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 5, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 5, nonce, round) txData = txDataBuilder.NewBuilder().Clear(). Func("mergeValidatorToDelegationWithWhitelist"). @@ -352,7 +353,7 @@ func TestValidatorToDelegationManagerWithWhiteListAndMerge(t *testing.T) { time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) testBLSKeyOwnerIsAddress(t, nodes, scAddressBytes, frontendBLSPubkey) @@ -378,7 +379,7 @@ func generateSendAndWaitToExecuteStakeTransaction( t *testing.T, nodes []*integrationTests.TestProcessorNode, stakingWalletAccount *integrationTests.TestWalletAccount, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nodePrice *big.Int, frontendBLSPubkey []byte, frontendHexSignature string, @@ -398,7 +399,7 @@ func generateSendAndWaitToExecuteStakeTransaction( time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 6 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) return nonce, round } @@ -407,7 +408,7 @@ func generateSendAndWaitToExecuteTransaction( t *testing.T, nodes []*integrationTests.TestProcessorNode, stakingWalletAccount *integrationTests.TestWalletAccount, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, function string, value *big.Int, serviceFee []byte, @@ -431,7 +432,7 @@ func generateSendAndWaitToExecuteTransaction( time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) return nonce, round } diff --git a/integrationTests/node/getAccount/getAccount_test.go b/integrationTests/node/getAccount/getAccount_test.go index 487c8b1a15a..acb4e92fd75 100644 --- a/integrationTests/node/getAccount/getAccount_test.go +++ b/integrationTests/node/getAccount/getAccount_test.go @@ -7,13 +7,16 @@ import ( chainData "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/node" "github.com/multiversx/mx-chain-go/state" 
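The getAccount tests in the hunks that follow build the enable-epochs handler explicitly from a generic epoch notifier and pass both into GetDefaultCoreComponents. A sketch of that wiring as a self-contained test, assuming the constructors keep the signatures used there; the test name is illustrative.

package integrationtestsexamples

import (
	"testing"

	"github.com/stretchr/testify/require"

	"github.com/multiversx/mx-chain-go/common/enablers"
	"github.com/multiversx/mx-chain-go/common/forking"
	"github.com/multiversx/mx-chain-go/integrationTests"
)

// TestWireDefaultCoreComponents shows the updated wiring: the enable-epochs handler is
// built from a generic epoch notifier, and both are handed to GetDefaultCoreComponents.
func TestWireDefaultCoreComponents(t *testing.T) {
	genericEpochNotifier := forking.NewGenericEpochNotifier()
	enableEpochsHandler, err := enablers.NewEnableEpochsHandler(integrationTests.CreateEnableEpochsConfig(), genericEpochNotifier)
	require.Nil(t, err)

	coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier)
	coreComponents.AddressPubKeyConverterField = integrationTests.TestAddressPubkeyConverter
	require.NotNil(t, coreComponents)
}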
"github.com/multiversx/mx-chain-go/state/blockInfoProviders" "github.com/multiversx/mx-chain-go/testscommon" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createAccountsRepository(accDB state.AccountsAdapter, blockchain chainData.ChainHandler) state.AccountsRepository { @@ -39,7 +42,9 @@ func TestNode_GetAccountAccountDoesNotExistsShouldRetEmpty(t *testing.T) { accDB, _ := integrationTests.CreateAccountsDB(0, trieStorage) rootHash, _ := accDB.Commit() - coreComponents := integrationTests.GetDefaultCoreComponents(integrationTests.CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(integrationTests.CreateEnableEpochsConfig(), genericEpochNotifier) + coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.AddressPubKeyConverterField = integrationTests.TestAddressPubkeyConverter dataComponents := integrationTests.GetDefaultDataComponents() @@ -81,7 +86,9 @@ func TestNode_GetAccountAccountExistsShouldReturn(t *testing.T) { testPubkey := integrationTests.CreateAccount(accDB, testNonce, testBalance) rootHash, _ := accDB.Commit() - coreComponents := integrationTests.GetDefaultCoreComponents(integrationTests.CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(integrationTests.CreateEnableEpochsConfig(), genericEpochNotifier) + coreComponents := integrationTests.GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.AddressPubKeyConverterField = testscommon.RealWorldBech32PubkeyConverter dataComponents := integrationTests.GetDefaultDataComponents() diff --git a/integrationTests/nodesCoordinatorFactory.go b/integrationTests/nodesCoordinatorFactory.go index 28267d44c5a..8154d6df5db 100644 --- a/integrationTests/nodesCoordinatorFactory.go +++ b/integrationTests/nodesCoordinatorFactory.go @@ -7,10 +7,12 @@ import ( "github.com/multiversx/mx-chain-core-go/data/endProcess" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests/mock" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" @@ -46,10 +48,6 @@ func (tpn *IndexHashedNodesCoordinatorFactory) CreateNodesCoordinator(arg ArgInd pubKeyBytes, _ := keys.MainKey.Pk.ToByteArray() nodeShufflerArgs := &nodesCoordinator.NodesShufflerArgs{ - NodesShard: uint32(arg.nodesPerShard), - NodesMeta: uint32(arg.nbMetaNodes), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, @@ -61,23 +59,33 @@ func (tpn *IndexHashedNodesCoordinatorFactory) 
CreateNodesCoordinator(arg ArgInd StakingV4Step2EnableEpoch, ) argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: arg.shardConsensusGroupSize, - MetaConsensusGroupSize: arg.metaConsensusGroupSize, - Marshalizer: TestMarshalizer, - Hasher: arg.hasher, - Shuffler: nodeShuffler, - EpochStartNotifier: arg.epochStartSubscriber, - ShardIDAsObserver: arg.shardId, - NbShards: uint32(arg.nbShards), - EligibleNodes: arg.validatorsMap, - WaitingNodes: arg.waitingMap, - SelfPublicKey: pubKeyBytes, - ConsensusGroupCache: arg.consensusGroupCache, - BootStorer: arg.bootStorer, - ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, - ChanStopNode: endProcess.GetDummyEndProcessChannel(), - NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, - IsFullArchive: false, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardMinNumNodes: uint32(arg.nodesPerShard), + MetachainMinNumNodes: uint32(arg.nbMetaNodes), + Hysteresis: hysteresis, + Adaptivity: adaptivity, + ShardConsensusGroupSize: uint32(arg.shardConsensusGroupSize), + MetachainConsensusGroupSize: uint32(arg.metaConsensusGroupSize), + }, nil + }, + }, + Marshalizer: TestMarshalizer, + Hasher: arg.hasher, + Shuffler: nodeShuffler, + EpochStartNotifier: arg.epochStartSubscriber, + ShardIDAsObserver: arg.shardId, + NbShards: uint32(arg.nbShards), + EligibleNodes: arg.validatorsMap, + WaitingNodes: arg.waitingMap, + SelfPublicKey: pubKeyBytes, + ConsensusGroupCache: arg.consensusGroupCache, + BootStorer: arg.bootStorer, + ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, + ChanStopNode: endProcess.GetDummyEndProcessChannel(), + NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, + IsFullArchive: false, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ GetActivationEpochCalled: func(flag core.EnableEpochFlag) uint32 { if flag == common.RefactorPeersMiniBlocksFlag || flag == common.StakingV4Step2Flag { @@ -112,10 +120,6 @@ func (ihncrf *IndexHashedNodesCoordinatorWithRaterFactory) CreateNodesCoordinato pubKeyBytes, _ := keys.MainKey.Pk.ToByteArray() shufflerArgs := &nodesCoordinator.NodesShufflerArgs{ - NodesShard: uint32(arg.nodesPerShard), - NodesMeta: uint32(arg.nbMetaNodes), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, @@ -127,23 +131,33 @@ func (ihncrf *IndexHashedNodesCoordinatorWithRaterFactory) CreateNodesCoordinato StakingV4Step2EnableEpoch, ) argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: arg.shardConsensusGroupSize, - MetaConsensusGroupSize: arg.metaConsensusGroupSize, - Marshalizer: TestMarshalizer, - Hasher: arg.hasher, - Shuffler: nodeShuffler, - EpochStartNotifier: arg.epochStartSubscriber, - ShardIDAsObserver: arg.shardId, - NbShards: uint32(arg.nbShards), - EligibleNodes: arg.validatorsMap, - WaitingNodes: arg.waitingMap, - SelfPublicKey: pubKeyBytes, - ConsensusGroupCache: arg.consensusGroupCache, - BootStorer: arg.bootStorer, - ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, - ChanStopNode: endProcess.GetDummyEndProcessChannel(), - NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, - IsFullArchive: false, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + 
ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardMinNumNodes: uint32(arg.nodesPerShard), + MetachainMinNumNodes: uint32(arg.nbMetaNodes), + Hysteresis: hysteresis, + Adaptivity: adaptivity, + ShardConsensusGroupSize: uint32(arg.shardConsensusGroupSize), + MetachainConsensusGroupSize: uint32(arg.metaConsensusGroupSize), + }, nil + }, + }, + Marshalizer: TestMarshalizer, + Hasher: arg.hasher, + Shuffler: nodeShuffler, + EpochStartNotifier: arg.epochStartSubscriber, + ShardIDAsObserver: arg.shardId, + NbShards: uint32(arg.nbShards), + EligibleNodes: arg.validatorsMap, + WaitingNodes: arg.waitingMap, + SelfPublicKey: pubKeyBytes, + ConsensusGroupCache: arg.consensusGroupCache, + BootStorer: arg.bootStorer, + ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, + ChanStopNode: endProcess.GetDummyEndProcessChannel(), + NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, + IsFullArchive: false, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ GetActivationEpochCalled: func(flag core.EnableEpochFlag) uint32 { if flag == common.RefactorPeersMiniBlocksFlag { diff --git a/integrationTests/p2p/antiflood/messageProcessor.go b/integrationTests/p2p/antiflood/messageProcessor.go index 5f56985861f..b61400f397e 100644 --- a/integrationTests/p2p/antiflood/messageProcessor.go +++ b/integrationTests/p2p/antiflood/messageProcessor.go @@ -5,6 +5,7 @@ import ( "sync/atomic" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/integrationTests/mock" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" @@ -30,7 +31,7 @@ func newMessageProcessor() *MessageProcessor { } // ProcessReceivedMessage is the callback function from the p2p side whenever a new message is received -func (mp *MessageProcessor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) error { +func (mp *MessageProcessor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) ([]byte, error) { atomic.AddUint32(&mp.numMessagesReceived, 1) atomic.AddUint64(&mp.sizeMessagesReceived, uint64(len(message.Data()))) @@ -38,7 +39,7 @@ func (mp *MessageProcessor) ProcessReceivedMessage(message p2p.MessageP2P, fromC af, _ := antiflood2.NewP2PAntiflood(&mock.PeerBlackListCacherStub{}, &mock.TopicAntiFloodStub{}, mp.FloodPreventer) err := af.CanProcessMessage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } } @@ -50,7 +51,7 @@ func (mp *MessageProcessor) ProcessReceivedMessage(message p2p.MessageP2P, fromC mp.messages[fromConnectedPeer] = append(mp.messages[fromConnectedPeer], message) - return nil + return []byte{}, nil } // NumMessagesProcessed returns the number of processed messages diff --git a/integrationTests/realcomponents/processorRunner.go b/integrationTests/realcomponents/processorRunner.go index 3f3f4837201..20a33dcffc8 100644 --- a/integrationTests/realcomponents/processorRunner.go +++ b/integrationTests/realcomponents/processorRunner.go @@ -95,7 +95,7 @@ func (pr *ProcessorRunner) createCoreComponents(tb testing.TB) { RatingsConfig: *pr.Config.RatingsConfig, EconomicsConfig: *pr.Config.EconomicsConfig, ImportDbConfig: *pr.Config.ImportDbConfig, - NodesFilename: pr.Config.ConfigurationPathsHolder.Nodes, + NodesConfig: *pr.Config.NodesConfig, WorkingDirectory: 
pr.Config.FlagsConfig.WorkingDir, ChanStopNodeProcess: make(chan endProcess.ArgEndProcess), } @@ -308,6 +308,7 @@ func (pr *ProcessorRunner) createStatusComponents(tb testing.TB) { pr.CoreComponents.EnableEpochsHandler(), pr.DataComponents.Datapool().CurrentEpochValidatorInfo(), pr.BootstrapComponents.NodesCoordinatorRegistryFactory(), + pr.CoreComponents.ChainParametersHandler(), ) require.Nil(tb, err) diff --git a/integrationTests/realcomponents/processorRunner_test.go b/integrationTests/realcomponents/processorRunner_test.go index 78d0013597e..ce2e60a48d3 100644 --- a/integrationTests/realcomponents/processorRunner_test.go +++ b/integrationTests/realcomponents/processorRunner_test.go @@ -3,8 +3,9 @@ package realcomponents import ( "testing" - "github.com/multiversx/mx-chain-go/testscommon" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/testscommon" ) func TestNewProcessorRunnerAndClose(t *testing.T) { diff --git a/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go b/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go index 2c7bb0f7a7c..b09e7892a17 100644 --- a/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go +++ b/integrationTests/singleShard/block/executingMiniblocks/executingMiniblocks_test.go @@ -11,12 +11,13 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/integrationTests" testBlock "github.com/multiversx/mx-chain-go/integrationTests/singleShard/block" "github.com/multiversx/mx-chain-go/process" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/assert" ) // TestShardShouldNotProposeAndExecuteTwoBlocksInSameRound tests that a shard can not continue building on a @@ -43,6 +44,7 @@ func TestShardShouldNotProposeAndExecuteTwoBlocksInSameRound(t *testing.T) { integrationTests.ConnectNodes(connectableNodes) idxProposer := 0 + leader := nodes[idxProposer] defer func() { for _, n := range nodes { @@ -57,24 +59,24 @@ func TestShardShouldNotProposeAndExecuteTwoBlocksInSameRound(t *testing.T) { nonce := uint64(1) round = integrationTests.IncrementAndPrintRound(round) - err := proposeAndCommitBlock(nodes[idxProposer], round, nonce) + err := proposeAndCommitBlock(leader, round, nonce) assert.Nil(t, err) - integrationTests.SyncBlock(t, nodes, []int{idxProposer}, nonce) + integrationTests.SyncBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, nonce) time.Sleep(testBlock.StepDelay) checkCurrentBlockHeight(t, nodes, nonce) - //only nonce increases, round stays the same + // only nonce increases, round stays the same nonce++ err = proposeAndCommitBlock(nodes[idxProposer], round, nonce) assert.Equal(t, process.ErrLowerRoundInBlock, err) - //mockTestingT is used as in normal case SyncBlock would fail as it doesn't find the header with nonce 2 + // mockTestingT is used as in normal case SyncBlock would fail as it doesn't find the header with nonce 2 mockTestingT := &testing.T{} - integrationTests.SyncBlock(mockTestingT, 
nodes, []int{idxProposer}, nonce) + integrationTests.SyncBlock(mockTestingT, nodes, []*integrationTests.TestProcessorNode{leader}, nonce) time.Sleep(testBlock.StepDelay) @@ -82,11 +84,11 @@ func TestShardShouldNotProposeAndExecuteTwoBlocksInSameRound(t *testing.T) { } // TestShardShouldProposeBlockContainingInvalidTransactions tests the following scenario: -// 1. generate 3 move balance transactions: one that can be executed, one to be processed as invalid, and one that isn't executable (no balance left for fee). -// 2. proposer will have those 3 transactions in its pools and will propose a block -// 3. another node will be able to sync the proposed block (and request - receive) the 2 transactions that -// will end up in the block (one valid and one invalid) -// 4. the non-executable transaction will not be immediately removed from the proposer's pool. See MX-16200. +// 1. generate 3 move balance transactions: one that can be executed, one to be processed as invalid, and one that isn't executable (no balance left for fee). +// 2. proposer will have those 3 transactions in its pools and will propose a block +// 3. another node will be able to sync the proposed block (and request - receive) the 2 transactions that +// will end up in the block (one valid and one invalid) +// 4. the non-executable transaction will not be immediately removed from the proposer's pool. See MX-16200. func TestShardShouldProposeBlockContainingInvalidTransactions(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") @@ -109,7 +111,7 @@ func TestShardShouldProposeBlockContainingInvalidTransactions(t *testing.T) { integrationTests.ConnectNodes(connectableNodes) idxProposer := 0 - proposer := nodes[idxProposer] + leader := nodes[idxProposer] defer func() { for _, n := range nodes { @@ -127,10 +129,10 @@ func TestShardShouldProposeBlockContainingInvalidTransactions(t *testing.T) { transferValue := uint64(1000000) mintAllNodes(nodes, transferValue) - txs, hashes := generateTransferTxs(transferValue, proposer.OwnAccount.SkTxSign, nodes[1].OwnAccount.PkTxSign) - addTxsInDataPool(proposer, txs, hashes) + txs, hashes := generateTransferTxs(transferValue, leader.OwnAccount.SkTxSign, nodes[1].OwnAccount.PkTxSign) + addTxsInDataPool(leader, txs, hashes) - _, _ = integrationTests.ProposeAndSyncOneBlock(t, nodes, []int{idxProposer}, round, nonce) + _, _ = integrationTests.ProposeAndSyncOneBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, round, nonce) fmt.Println(integrationTests.MakeDisplayTable(nodes)) @@ -218,7 +220,6 @@ func testSameBlockHeight(t *testing.T, nodes []*integrationTests.TestProcessorNo } } - func testTxIsInMiniblock(t *testing.T, proposer *integrationTests.TestProcessorNode, hash []byte, bt block.Type) { hdrHandler := proposer.BlockChain.GetCurrentBlockHeader() hdr := hdrHandler.(*block.Header) diff --git a/integrationTests/singleShard/block/executingMiniblocksSc/executingMiniblocksSc_test.go b/integrationTests/singleShard/block/executingMiniblocksSc/executingMiniblocksSc_test.go index 81bf80dca55..238503d006a 100644 --- a/integrationTests/singleShard/block/executingMiniblocksSc/executingMiniblocksSc_test.go +++ b/integrationTests/singleShard/block/executingMiniblocksSc/executingMiniblocksSc_test.go @@ -9,10 +9,11 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" 
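The antiflood MessageProcessor above now implements a processor interface whose ProcessReceivedMessage returns a payload alongside the error: nil plus the error on rejection, an empty slice plus nil on success. A minimal sketch of a processor written against that shape; the type and its counters are illustrative, only the signature and return convention follow the change above.

package integrationtestsexamples

import (
	"sync/atomic"

	"github.com/multiversx/mx-chain-core-go/core"

	"github.com/multiversx/mx-chain-go/p2p"
)

// countingProcessor is an illustrative processor under the updated interface: the first
// return value carries an optional payload (empty here), the second the error.
type countingProcessor struct {
	numMessagesReceived  uint32
	sizeMessagesReceived uint64
}

// ProcessReceivedMessage counts the message and its size, then reports success with an
// empty payload, mirroring the `return []byte{}, nil` / `return nil, err` convention.
func (cp *countingProcessor) ProcessReceivedMessage(
	message p2p.MessageP2P,
	_ core.PeerID,
	_ p2p.MessageHandler,
) ([]byte, error) {
	atomic.AddUint32(&cp.numMessagesReceived, 1)
	atomic.AddUint64(&cp.sizeMessagesReceived, uint64(len(message.Data())))

	return []byte{}, nil
}

// IsInterfaceNil returns true if there is no value under the interface
func (cp *countingProcessor) IsInterfaceNil() bool {
	return cp == nil
}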
"github.com/multiversx/mx-chain-go/integrationTests/singleShard/block" "github.com/multiversx/mx-chain-go/process/factory" - "github.com/stretchr/testify/assert" ) func TestShouldProcessMultipleERC20ContractsInSingleShard(t *testing.T) { @@ -40,10 +41,11 @@ func TestShouldProcessMultipleERC20ContractsInSingleShard(t *testing.T) { integrationTests.ConnectNodes(connectableNodes) idxProposer := 0 + leader := nodes[idxProposer] numPlayers := 10 players := make([]*integrationTests.TestWalletAccount, numPlayers) for i := 0; i < numPlayers; i++ { - players[i] = integrationTests.CreateTestWalletAccount(nodes[idxProposer].ShardCoordinator, 0) + players[i] = integrationTests.CreateTestWalletAccount(leader.ShardCoordinator, 0) } defer func() { @@ -62,7 +64,7 @@ func TestShouldProcessMultipleERC20ContractsInSingleShard(t *testing.T) { hardCodedSk, _ := hex.DecodeString("5561d28b0d89fa425bbbf9e49a018b5d1e4a462c03d2efce60faf9ddece2af06") hardCodedScResultingAddress, _ := hex.DecodeString("000000000000000005006c560111a94e434413c1cdaafbc3e1348947d1d5b3a1") - nodes[idxProposer].LoadTxSignSkBytes(hardCodedSk) + leader.LoadTxSignSkBytes(hardCodedSk) initialVal := big.NewInt(100000000000) integrationTests.MintAllNodes(nodes, initialVal) @@ -70,11 +72,11 @@ func TestShouldProcessMultipleERC20ContractsInSingleShard(t *testing.T) { integrationTests.DeployScTx(nodes, idxProposer, hex.EncodeToString(scCode), factory.WasmVirtualMachine, "001000000000") time.Sleep(block.StepDelay) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []int{idxProposer}, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, round, nonce) playersDoTopUp(nodes[idxProposer], players, hardCodedScResultingAddress, big.NewInt(10000000)) time.Sleep(block.StepDelay) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []int{idxProposer}, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, round, nonce) for i := 0; i < 100; i++ { playersDoTransfer(nodes[idxProposer], players, hardCodedScResultingAddress, big.NewInt(100)) @@ -82,7 +84,7 @@ func TestShouldProcessMultipleERC20ContractsInSingleShard(t *testing.T) { for i := 0; i < 10; i++ { time.Sleep(block.StepDelay) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []int{idxProposer}, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, []*integrationTests.TestProcessorNode{leader}, round, nonce) } integrationTests.CheckRootHashes(t, nodes, []int{idxProposer}) diff --git a/integrationTests/state/stateTrie/stateTrie_test.go b/integrationTests/state/stateTrie/stateTrie_test.go index 12ec5115d28..269e6ff103f 100644 --- a/integrationTests/state/stateTrie/stateTrie_test.go +++ b/integrationTests/state/stateTrie/stateTrie_test.go @@ -24,6 +24,10 @@ import ( dataTx "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-core-go/hashing/sha256" crypto "github.com/multiversx/mx-chain-crypto-go" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" 
"github.com/multiversx/mx-chain-go/common/holders" @@ -49,9 +53,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" testStorage "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const denomination = "000000000000000000" @@ -1299,7 +1300,7 @@ func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) numNodesPerShard := 1 numNodesMeta := 1 - nodes, idxProposers := integrationTests.SetupSyncNodesOneShardAndMeta(numNodesPerShard, numNodesMeta) + nodes, leaders := integrationTests.SetupSyncNodesOneShardAndMeta(numNodesPerShard, numNodesMeta) defer integrationTests.CloseProcessorNodes(nodes) integrationTests.BootstrapDelay() @@ -1331,7 +1332,7 @@ func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) round = integrationTests.IncrementAndPrintRound(round) nonce++ - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) rootHashOfFirstBlock, _ := shardNode.AccntState.RootHash() @@ -1340,7 +1341,7 @@ func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) delayRounds := 10 for i := 0; i < delayRounds; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) } fmt.Println("Generating transactions...") @@ -1357,7 +1358,7 @@ func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) fmt.Println("Delaying for disseminating transactions...") time.Sleep(time.Second * 5) - round, _ = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, _ = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) time.Sleep(time.Second * 5) rootHashOfRollbackedBlock, _ := shardNode.AccntState.RootHash() @@ -1390,7 +1391,7 @@ func TestRollbackBlockAndCheckThatPruningIsCancelledOnAccountsTrie(t *testing.T) integrationTests.ProposeBlocks( nodes, &round, - idxProposers, + leaders, nonces, numOfRounds, ) @@ -1559,11 +1560,11 @@ func TestStatePruningIsNotBuffered(t *testing.T) { ) shardNode := nodes[0] - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1583,21 +1584,21 @@ func TestStatePruningIsNotBuffered(t *testing.T) { time.Sleep(integrationTests.StepDelay) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) delayRounds := 5 for j := 0; j < 8; j++ { // alter the shardNode's state by placing the value0 variable inside it's data trie alterState(t, shardNode, nodes, []byte("key"), []byte("value0")) for i := 0; i < delayRounds; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, 
nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) } checkTrieCanBeRecreated(t, shardNode) // alter the shardNode's state by placing the value1 variable inside it's data trie alterState(t, shardNode, nodes, []byte("key"), []byte("value1")) for i := 0; i < delayRounds; i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) } checkTrieCanBeRecreated(t, shardNode) } @@ -1619,11 +1620,11 @@ func TestStatePruningIsNotBufferedOnConsecutiveBlocks(t *testing.T) { ) shardNode := nodes[0] - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1643,17 +1644,17 @@ func TestStatePruningIsNotBufferedOnConsecutiveBlocks(t *testing.T) { time.Sleep(integrationTests.StepDelay) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) for j := 0; j < 30; j++ { // alter the shardNode's state by placing the value0 variable inside it's data trie alterState(t, shardNode, nodes, []byte("key"), []byte("value0")) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) checkTrieCanBeRecreated(t, shardNode) // alter the shardNode's state by placing the value1 variable inside it's data trie alterState(t, shardNode, nodes, []byte("key"), []byte("value1")) - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) checkTrieCanBeRecreated(t, shardNode) } } @@ -1733,11 +1734,11 @@ func TestSnapshotOnEpochChange(t *testing.T) { node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -1767,7 +1768,7 @@ func TestSnapshotOnEpochChange(t *testing.T) { numRounds := uint32(20) for i := uint64(0); i < uint64(numRounds); i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) for _, node := range nodes { integrationTests.CreateAndSendTransaction(node, nodes, sendValue, receiverAddress, "", integrationTests.AdditionalGasLimit) @@ -1786,7 +1787,7 @@ func TestSnapshotOnEpochChange(t *testing.T) { numDelayRounds := uint32(15) for i := uint64(0); i < uint64(numDelayRounds); i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) for _, node := range nodes { integrationTests.CreateAndSendTransaction(node, nodes, 
sendValue, receiverAddress, "", integrationTests.AdditionalGasLimit) @@ -2455,7 +2456,7 @@ func migrateDataTrieBuiltInFunc( migrationAddress []byte, nonce uint64, round uint64, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, ) { require.True(t, nodes[shardId].EnableEpochsHandler.IsFlagEnabled(common.AutoBalanceDataTriesFlag)) isMigrated := getAddressMigrationStatus(t, nodes[shardId].AccntState, migrationAddress) @@ -2465,7 +2466,7 @@ func migrateDataTrieBuiltInFunc( time.Sleep(time.Second) nrRoundsToPropagate := 5 - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagate, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagate, nonce, round) isMigrated = getAddressMigrationStatus(t, nodes[shardId].AccntState, migrationAddress) require.True(t, isMigrated) @@ -2475,7 +2476,7 @@ func startNodesAndIssueToken( t *testing.T, numOfShards int, issuerShardId byte, -) ([]*integrationTests.TestProcessorNode, []int, uint64, uint64) { +) (nodes []*integrationTests.TestProcessorNode, leaders []*integrationTests.TestProcessorNode, nonce uint64, round uint64) { nodesPerShard := 1 numMetachainNodes := 1 @@ -2489,9 +2490,10 @@ func startNodesAndIssueToken( StakingV4Step1EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step2EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step3EnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, AutoBalanceDataTriesEnableEpoch: 1, } - nodes := integrationTests.CreateNodesWithEnableEpochs( + nodes = integrationTests.CreateNodesWithEnableEpochs( numOfShards, nodesPerShard, numMetachainNodes, @@ -2503,19 +2505,19 @@ func startNodesAndIssueToken( node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders = make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) initialVal := int64(10000000000) integrationTests.MintAllNodes(nodes, big.NewInt(initialVal)) - round := uint64(0) - nonce := uint64(0) + round = uint64(0) + nonce = uint64(0) round = integrationTests.IncrementAndPrintRound(round) nonce++ @@ -2526,14 +2528,14 @@ func startNodesAndIssueToken( time.Sleep(time.Second) nrRoundsToPropagate := 8 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagate, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagate, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) esdtCommon.CheckAddressHasTokens(t, nodes[issuerShardId].OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply) - return nodes, idxProposers, nonce, round + return nodes, leaders, nonce, round } func getDestAccountAddress(migrationAddress []byte, shardId byte) []byte { diff --git a/integrationTests/state/stateTrieSync/stateTrieSync_test.go b/integrationTests/state/stateTrieSync/stateTrieSync_test.go index 74650d4ce11..7ccc5255cb0 100644 --- a/integrationTests/state/stateTrieSync/stateTrieSync_test.go +++ b/integrationTests/state/stateTrieSync/stateTrieSync_test.go @@ -10,6 +10,11 @@ import (
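The recurring change in the test files above replaces proposer indices (idxProposers []int) with the proposer nodes themselves (leaders []*integrationTests.TestProcessorNode), always taking the first node of each shard group plus the first metachain node. A minimal sketch of that selection convention, assuming the flat node layout these tests use (shard nodes first, nodesPerShard entries per shard, metachain nodes last); the helper name and package clause are illustrative and not part of this change:

package stateTrie

import (
    "github.com/multiversx/mx-chain-go/integrationTests"
)

// leadersFromNodes is a hypothetical helper that captures the convention used by
// these tests: nodes are laid out shard by shard, nodesPerShard entries each,
// followed by the metachain nodes, and the first node of every group acts as leader.
func leadersFromNodes(nodes []*integrationTests.TestProcessorNode, numOfShards, nodesPerShard int) []*integrationTests.TestProcessorNode {
    leaders := make([]*integrationTests.TestProcessorNode, 0, numOfShards+1)
    for i := 0; i < numOfShards; i++ {
        // the first node of shard i proposes for that shard
        leaders = append(leaders, nodes[i*nodesPerShard])
    }
    // the metachain proposer is the first node placed after all shard nodes
    leaders = append(leaders, nodes[numOfShards*nodesPerShard])
    return leaders
}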
"github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/throttler" + logger "github.com/multiversx/mx-chain-logger-go" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/holders" @@ -28,10 +33,6 @@ import ( "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/multiversx/mx-chain-go/trie/storageMarker" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" - logger "github.com/multiversx/mx-chain-logger-go" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var log = logger.GetOrCreate("integrationtests/state/statetriesync") @@ -449,11 +450,11 @@ func testSyncMissingSnapshotNodes(t *testing.T, version int) { node.EpochStartTrigger.SetRoundsPerEpoch(roundsPerEpoch) } - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -476,7 +477,7 @@ func testSyncMissingSnapshotNodes(t *testing.T, version int) { nonce++ numDelayRounds := uint32(10) for i := uint64(0); i < uint64(numDelayRounds); i++ { - round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = integrationTests.ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) time.Sleep(integrationTests.StepDelay) } diff --git a/integrationTests/sync/basicSync/basicSync_test.go b/integrationTests/sync/basicSync/basicSync_test.go index 52cc2c7af79..2d75ec9e10f 100644 --- a/integrationTests/sync/basicSync/basicSync_test.go +++ b/integrationTests/sync/basicSync/basicSync_test.go @@ -8,9 +8,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" - "github.com/multiversx/mx-chain-go/integrationTests" logger "github.com/multiversx/mx-chain-logger-go" "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/integrationTests" ) var log = logger.GetOrCreate("basicSync") @@ -19,7 +20,6 @@ func TestSyncWorksInShard_EmptyBlocksNoForks(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - maxShards := uint32(1) shardId := uint32(0) numNodesPerShard := 6 @@ -47,7 +47,7 @@ func TestSyncWorksInShard_EmptyBlocksNoForks(t *testing.T) { connectableNodes = append(connectableNodes, metachainNode) idxProposerShard0 := 0 - idxProposers := []int{idxProposerShard0, idxProposerMeta} + leaders := []*integrationTests.TestProcessorNode{nodes[idxProposerShard0], nodes[idxProposerMeta]} integrationTests.ConnectNodes(connectableNodes) @@ -72,7 +72,7 @@ func TestSyncWorksInShard_EmptyBlocksNoForks(t 
*testing.T) { numRoundsToTest := 5 for i := 0; i < numRoundsToTest; i++ { - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) time.Sleep(integrationTests.SyncDelay) @@ -110,7 +110,7 @@ func TestSyncWorksInShard_EmptyBlocksDoubleSign(t *testing.T) { integrationTests.ConnectNodes(connectableNodes) idxProposerShard0 := 0 - idxProposers := []int{idxProposerShard0} + leaders := []*integrationTests.TestProcessorNode{nodes[idxProposerShard0]} defer func() { for _, n := range nodes { @@ -133,7 +133,7 @@ func TestSyncWorksInShard_EmptyBlocksDoubleSign(t *testing.T) { numRoundsToTest := 2 for i := 0; i < numRoundsToTest; i++ { - integrationTests.ProposeBlock(nodes, idxProposers, round, nonce) + integrationTests.ProposeBlock(nodes, leaders, round, nonce) time.Sleep(integrationTests.SyncDelay) @@ -197,3 +197,179 @@ func testAllNodesHaveSameLastBlock(t *testing.T, nodes []*integrationTests.TestP assert.Equal(t, 1, len(mapBlocksByHash)) } + +func TestSyncWorksInShard_EmptyBlocksNoForks_With_EquivalentProofs(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + // 3 shard nodes and 1 metachain node + maxShards := uint32(1) + shardId := uint32(0) + numNodesPerShard := 3 + + enableEpochs := integrationTests.CreateEnableEpochsConfig() + enableEpochs.AndromedaEnableEpoch = uint32(0) + + nodes := make([]*integrationTests.TestProcessorNode, numNodesPerShard+1) + connectableNodes := make([]integrationTests.Connectable, 0) + for i := 0; i < numNodesPerShard; i++ { + nodes[i] = integrationTests.NewTestProcessorNode(integrationTests.ArgTestProcessorNode{ + MaxShards: maxShards, + NodeShardId: shardId, + TxSignPrivKeyShardId: shardId, + WithSync: true, + EpochsConfig: &enableEpochs, + }) + connectableNodes = append(connectableNodes, nodes[i]) + } + + metachainNode := integrationTests.NewTestProcessorNode(integrationTests.ArgTestProcessorNode{ + MaxShards: maxShards, + NodeShardId: core.MetachainShardId, + TxSignPrivKeyShardId: shardId, + WithSync: true, + EpochsConfig: &enableEpochs, + }) + idxProposerMeta := numNodesPerShard + nodes[idxProposerMeta] = metachainNode + connectableNodes = append(connectableNodes, metachainNode) + + idxProposerShard0 := 0 + leaders := []*integrationTests.TestProcessorNode{nodes[idxProposerShard0], nodes[idxProposerMeta]} + + integrationTests.ConnectNodes(connectableNodes) + + defer func() { + for _, n := range nodes { + n.Close() + } + }() + + for _, n := range nodes { + _ = n.StartSync() + } + + fmt.Println("Delaying for nodes p2p bootstrap...") + time.Sleep(integrationTests.P2pBootstrapDelay) + + round := uint64(0) + nonce := uint64(0) + round = integrationTests.IncrementAndPrintRound(round) + integrationTests.UpdateRound(nodes, round) + nonce++ + + numRoundsToTest := 5 + + for i := 0; i < numRoundsToTest; i++ { + integrationTests.ProposeBlockWithProof(nodes, leaders, round, nonce) + + time.Sleep(integrationTests.SyncDelay) + + round = integrationTests.IncrementAndPrintRound(round) + integrationTests.UpdateRound(nodes, round) + nonce++ + } + + time.Sleep(integrationTests.SyncDelay) + + expectedNonce := nodes[0].BlockChain.GetCurrentBlockHeader().GetNonce() + for i := 1; i < len(nodes); i++ { + if check.IfNil(nodes[i].BlockChain.GetCurrentBlockHeader()) { + assert.Fail(t, fmt.Sprintf("Node with idx %d does not have a current block", i)) + } else { + // all nodes must have proofs now + assert.Equal(t, expectedNonce, nodes[i].BlockChain.GetCurrentBlockHeader().GetNonce()) + } + } 
+} + +func TestSyncMetaAndShard_With_EquivalentProofs(t *testing.T) { + if testing.Short() { + t.Skip("this is not a short test") + } + + // 3 shard nodes and 3 metachain node + maxShards := uint32(1) + shardId := uint32(0) + numNodesPerShard := 3 + + enableEpochs := integrationTests.CreateEnableEpochsConfig() + enableEpochs.AndromedaEnableEpoch = uint32(0) + + nodes := make([]*integrationTests.TestProcessorNode, 2*numNodesPerShard) + leaders := make([]*integrationTests.TestProcessorNode, 0) + connectableNodes := make([]integrationTests.Connectable, 0) + + for i := 0; i < numNodesPerShard; i++ { + nodes[i] = integrationTests.NewTestProcessorNode(integrationTests.ArgTestProcessorNode{ + MaxShards: maxShards, + NodeShardId: shardId, + TxSignPrivKeyShardId: shardId, + WithSync: true, + EpochsConfig: &enableEpochs, + }) + connectableNodes = append(connectableNodes, nodes[i]) + } + + idxProposerShard0 := 0 + leaders = append(leaders, nodes[idxProposerShard0]) + + idxProposerMeta := numNodesPerShard + for i := 0; i < numNodesPerShard; i++ { + metachainNode := integrationTests.NewTestProcessorNode(integrationTests.ArgTestProcessorNode{ + MaxShards: maxShards, + NodeShardId: core.MetachainShardId, + TxSignPrivKeyShardId: shardId, + WithSync: true, + EpochsConfig: &enableEpochs, + }) + nodes[idxProposerMeta+i] = metachainNode + connectableNodes = append(connectableNodes, metachainNode) + } + leaders = append(leaders, nodes[idxProposerMeta]) + + integrationTests.ConnectNodes(connectableNodes) + + defer func() { + for _, n := range nodes { + n.Close() + } + }() + + for _, n := range nodes { + _ = n.StartSync() + } + + fmt.Println("Delaying for nodes p2p bootstrap...") + time.Sleep(integrationTests.P2pBootstrapDelay) + + round := uint64(0) + nonce := uint64(0) + round = integrationTests.IncrementAndPrintRound(round) + integrationTests.UpdateRound(nodes, round) + nonce++ + + numRoundsToTest := 5 + for i := 0; i < numRoundsToTest; i++ { + integrationTests.ProposeBlockWithProof(nodes, leaders, round, nonce) + + time.Sleep(integrationTests.SyncDelay) + + round = integrationTests.IncrementAndPrintRound(round) + integrationTests.UpdateRound(nodes, round) + nonce++ + } + + time.Sleep(integrationTests.SyncDelay) + + expectedNonce := nodes[0].BlockChain.GetCurrentBlockHeader().GetNonce() + for i := 1; i < len(nodes); i++ { + if check.IfNil(nodes[i].BlockChain.GetCurrentBlockHeader()) { + assert.Fail(t, fmt.Sprintf("Node with idx %d does not have a current block", i)) + } else { + // all nodes must have proofs now + assert.Equal(t, expectedNonce, nodes[i].BlockChain.GetCurrentBlockHeader().GetNonce()) + } + } +} diff --git a/integrationTests/sync/edgeCases/edgeCases_test.go b/integrationTests/sync/edgeCases/edgeCases_test.go index f3167b0528e..285fed4dd8c 100644 --- a/integrationTests/sync/edgeCases/edgeCases_test.go +++ b/integrationTests/sync/edgeCases/edgeCases_test.go @@ -6,9 +6,10 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" - "github.com/multiversx/mx-chain-go/integrationTests" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" ) // TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard tests the following scenario: @@ -24,8 +25,8 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { numNodesPerShard := 3 numNodesMeta := 3 - nodes, idxProposers := 
integrationTests.SetupSyncNodesOneShardAndMeta(numNodesPerShard, numNodesMeta) - idxProposerMeta := idxProposers[1] + nodes, leaders := integrationTests.SetupSyncNodesOneShardAndMeta(numNodesPerShard, numNodesMeta) + leaderMeta := leaders[1] defer integrationTests.CloseProcessorNodes(nodes) integrationTests.BootstrapDelay() @@ -44,7 +45,7 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { integrationTests.ProposeBlocks( nodes, &round, - idxProposers, + leaders, nonces, numRoundsBlocksAreProposedCorrectly, ) @@ -54,14 +55,14 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { integrationTests.ResetHighestProbableNonce(nodes, shardIdToRollbackLastBlock, 2) integrationTests.EmptyDataPools(nodes, shardIdToRollbackLastBlock) - //revert also the nonce, so the same block nonce will be used when shard will propose the next block + // revert also the nonce, so the same block nonce will be used when shard will propose the next block atomic.AddUint64(nonces[idxNonceShard], ^uint64(0)) numRoundsBlocksAreProposedOnlyByMeta := 2 integrationTests.ProposeBlocks( nodes, &round, - []int{idxProposerMeta}, + []*integrationTests.TestProcessorNode{leaderMeta}, []*uint64{nonces[idxNonceMeta]}, numRoundsBlocksAreProposedOnlyByMeta, ) @@ -70,7 +71,7 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { integrationTests.ProposeBlocks( nodes, &round, - idxProposers, + leaders, nonces, secondNumRoundsBlocksAreProposedCorrectly, ) @@ -99,12 +100,12 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { integrationTests.StartSyncingBlocks(syncNodesSlice) - //after joining the network we must propose a new block on the metachain as to be received by the sync - //node and to start the bootstrapping process + // after joining the network we must propose a new block on the metachain as to be received by the sync + // node and to start the bootstrapping process integrationTests.ProposeBlocks( nodes, &round, - []int{idxProposerMeta}, + []*integrationTests.TestProcessorNode{leaderMeta}, []*uint64{nonces[idxNonceMeta]}, 1, ) @@ -115,7 +116,7 @@ func TestSyncMetaNodeIsSyncingReceivedHigherRoundBlockFromShard(t *testing.T) { time.Sleep(integrationTests.SyncDelay * time.Duration(numOfRoundsToWaitToCatchUp)) integrationTests.UpdateRound(nodes, round) - nonceProposerMeta := nodes[idxProposerMeta].BlockChain.GetCurrentBlockHeader().GetNonce() + nonceProposerMeta := leaderMeta.BlockChain.GetCurrentBlockHeader().GetNonce() nonceSyncNode := syncMetaNode.BlockChain.GetCurrentBlockHeader().GetNonce() assert.Equal(t, nonceProposerMeta, nonceSyncNode) } diff --git a/integrationTests/testConsensusNode.go b/integrationTests/testConsensusNode.go index 5f5987b11cf..282d14b6bbd 100644 --- a/integrationTests/testConsensusNode.go +++ b/integrationTests/testConsensusNode.go @@ -16,11 +16,18 @@ import ( crypto "github.com/multiversx/mx-chain-crypto-go" mclMultiSig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/multisig" "github.com/multiversx/mx-chain-crypto-go/signing/multisig" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/round" 
"github.com/multiversx/mx-chain-go/dataRetriever" + epochStartDisabled "github.com/multiversx/mx-chain-go/epochStart/bootstrap/disabled" "github.com/multiversx/mx-chain-go/epochStart/metachain" "github.com/multiversx/mx-chain-go/epochStart/notifier" + "github.com/multiversx/mx-chain-go/epochStart/shardchain" cryptoFactory "github.com/multiversx/mx-chain-go/factory/crypto" "github.com/multiversx/mx-chain-go/factory/peerSignatureHandler" "github.com/multiversx/mx-chain-go/integrationTests/mock" @@ -29,7 +36,14 @@ import ( "github.com/multiversx/mx-chain-go/ntp" "github.com/multiversx/mx-chain-go/p2p" p2pFactory "github.com/multiversx/mx-chain-go/p2p/factory" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" + "github.com/multiversx/mx-chain-go/process/factory/interceptorscontainer" + "github.com/multiversx/mx-chain-go/process/interceptors" + disabledInterceptors "github.com/multiversx/mx-chain-go/process/interceptors/disabled" + interceptorsFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" + processMock "github.com/multiversx/mx-chain-go/process/mock" + "github.com/multiversx/mx-chain-go/process/smartContract" syncFork "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/sharding" chainShardingMocks "github.com/multiversx/mx-chain-go/sharding/mock" @@ -39,8 +53,11 @@ import ( "github.com/multiversx/mx-chain-go/storage/cache" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" testFactory "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" @@ -63,32 +80,36 @@ var testPubkeyConverter, _ = pubkeyConverter.NewHexPubkeyConverter(32) // ArgsTestConsensusNode represents the arguments for the test consensus node constructor(s) type ArgsTestConsensusNode struct { - ShardID uint32 - ConsensusSize int - RoundTime uint64 - ConsensusType string - NodeKeys *TestNodeKeys - EligibleMap map[uint32][]nodesCoordinator.Validator - WaitingMap map[uint32][]nodesCoordinator.Validator - KeyGen crypto.KeyGenerator - P2PKeyGen crypto.KeyGenerator - MultiSigner *cryptoMocks.MultisignerMock - StartTime int64 + ShardID uint32 + ConsensusSize int + RoundTime uint64 + ConsensusType string + NodeKeys *TestNodeKeys + EligibleMap map[uint32][]nodesCoordinator.Validator + WaitingMap map[uint32][]nodesCoordinator.Validator + KeyGen crypto.KeyGenerator + P2PKeyGen crypto.KeyGenerator + MultiSigner *cryptoMocks.MultisignerMock + StartTime int64 + 
EnableEpochsConfig config.EnableEpochs } // TestConsensusNode represents a structure used in integration tests used for consensus tests type TestConsensusNode struct { - Node *node.Node - MainMessenger p2p.Messenger - FullArchiveMessenger p2p.Messenger - NodesCoordinator nodesCoordinator.NodesCoordinator - ShardCoordinator sharding.Coordinator - ChainHandler data.ChainHandler - BlockProcessor *mock.BlockProcessorMock - RequestersFinder dataRetriever.RequestersFinder - AccountsDB *state.AccountsDB - NodeKeys *TestKeyPair - MultiSigner *cryptoMocks.MultisignerMock + Node *node.Node + MainMessenger p2p.Messenger + FullArchiveMessenger p2p.Messenger + NodesCoordinator nodesCoordinator.NodesCoordinator + ShardCoordinator sharding.Coordinator + ChainHandler data.ChainHandler + BlockProcessor *mock.BlockProcessorMock + RequestersFinder dataRetriever.RequestersFinder + AccountsDB *state.AccountsDB + NodeKeys *TestKeyPair + MultiSigner *cryptoMocks.MultisignerMock + MainInterceptorsContainer process.InterceptorsContainer + DataPool dataRetriever.PoolsHolder + RequestHandler process.RequestHandler } // NewTestConsensusNode returns a new TestConsensusNode @@ -114,6 +135,7 @@ func CreateNodesWithTestConsensusNode( roundTime uint64, consensusType string, numKeysOnEachNode int, + enableEpochsConfig config.EnableEpochs, ) map[uint32][]*TestConsensusNode { nodes := make(map[uint32][]*TestConsensusNode, nodesPerShard) @@ -133,17 +155,18 @@ func CreateNodesWithTestConsensusNode( multiSignerMock := createCustomMultiSignerMock(multiSigner) args := ArgsTestConsensusNode{ - ShardID: shardID, - ConsensusSize: consensusSize, - RoundTime: roundTime, - ConsensusType: consensusType, - NodeKeys: keysPair, - EligibleMap: eligibleMap, - WaitingMap: waitingMap, - KeyGen: cp.KeyGen, - P2PKeyGen: cp.P2PKeyGen, - MultiSigner: multiSignerMock, - StartTime: startTime, + ShardID: shardID, + ConsensusSize: consensusSize, + RoundTime: roundTime, + ConsensusType: consensusType, + NodeKeys: keysPair, + EligibleMap: eligibleMap, + WaitingMap: waitingMap, + KeyGen: cp.KeyGen, + P2PKeyGen: cp.P2PKeyGen, + MultiSigner: multiSignerMock, + StartTime: startTime, + EnableEpochsConfig: enableEpochsConfig, } tcn := NewTestConsensusNode(args) @@ -178,6 +201,8 @@ func createCustomMultiSignerMock(multiSigner crypto.MultiSigner) *cryptoMocks.Mu } func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { + var err error + testHasher := createHasher(args.ConsensusType) epochStartRegistrationHandler := notifier.NewEpochStartSubscriptionHandler() consensusCache, _ := cache.NewLRUCache(10000) @@ -187,11 +212,18 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { tcn.MainMessenger = CreateMessengerWithNoDiscovery() tcn.FullArchiveMessenger = &p2pmocks.MessengerStub{} tcn.initBlockChain(testHasher) - tcn.initBlockProcessor() + tcn.initBlockProcessor(tcn.ShardCoordinator.SelfId()) syncer := ntp.NewSyncTime(ntp.NewNTPGoogleConfig(), nil) syncer.StartSyncingTime() + genericEpochNotifier := forking.NewGenericEpochNotifier() + + epochsConfig := GetDefaultEnableEpochsConfig() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(*epochsConfig, genericEpochNotifier) + + storage := CreateStore(tcn.ShardCoordinator.NumberOfShards()) + roundHandler, _ := round.NewRound( time.Unix(args.StartTime, 0), syncer.CurrentTime(), @@ -200,29 +232,63 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { 0) dataPool := dataRetrieverMock.CreatePoolsHolder(1, 0) + tcn.DataPool = dataPool - argsNewMetaEpochStart := 
&metachain.ArgsNewMetaEpochStartTrigger{ - GenesisTime: time.Unix(args.StartTime, 0), - EpochStartNotifier: notifier.NewEpochStartSubscriptionHandler(), - Settings: &config.EpochStartConfig{ - MinRoundsBetweenEpochs: 1, - RoundsPerEpoch: 1000, - }, - Epoch: 0, - Storage: createTestStore(), - Marshalizer: TestMarshalizer, - Hasher: testHasher, - AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, - DataPool: dataPool, + var epochTrigger TestEpochStartTrigger + if tcn.ShardCoordinator.SelfId() == core.MetachainShardId { + argsNewMetaEpochStart := &metachain.ArgsNewMetaEpochStartTrigger{ + GenesisTime: time.Unix(args.StartTime, 0), + EpochStartNotifier: notifier.NewEpochStartSubscriptionHandler(), + Settings: &config.EpochStartConfig{ + MinRoundsBetweenEpochs: 1, + RoundsPerEpoch: 1000, + }, + Epoch: 0, + Storage: createTestStore(), + Marshalizer: TestMarshalizer, + Hasher: testHasher, + AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, + DataPool: dataPool, + } + epochStartTrigger, err := metachain.NewEpochStartTrigger(argsNewMetaEpochStart) + if err != nil { + fmt.Println(err.Error()) + } + epochTrigger = &metachain.TestTrigger{} + epochTrigger.SetTrigger(epochStartTrigger) + } else { + argsPeerMiniBlocksSyncer := shardchain.ArgPeerMiniBlockSyncer{ + MiniBlocksPool: tcn.DataPool.MiniBlocks(), + ValidatorsInfoPool: tcn.DataPool.ValidatorsInfo(), + RequestHandler: &testscommon.RequestHandlerStub{}, + } + peerMiniBlockSyncer, _ := shardchain.NewPeerMiniBlockSyncer(argsPeerMiniBlocksSyncer) + + argsShardEpochStart := &shardchain.ArgsShardEpochStartTrigger{ + Marshalizer: TestMarshalizer, + Hasher: TestHasher, + HeaderValidator: &mock.HeaderValidatorStub{}, + Uint64Converter: TestUint64Converter, + DataPool: tcn.DataPool, + Storage: storage, + RequestHandler: &testscommon.RequestHandlerStub{}, + Epoch: 0, + Validity: 1, + Finality: 1, + EpochStartNotifier: notifier.NewEpochStartSubscriptionHandler(), + PeerMiniBlocksSyncer: peerMiniBlockSyncer, + RoundHandler: roundHandler, + AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, + EnableEpochsHandler: enableEpochsHandler, + } + epochStartTrigger, err := shardchain.NewEpochStartTrigger(argsShardEpochStart) + if err != nil { + fmt.Println("NewEpochStartTrigger shard") + fmt.Println(err.Error()) + } + epochTrigger = &shardchain.TestTrigger{} + epochTrigger.SetTrigger(epochStartTrigger) } - epochStartTrigger, _ := metachain.NewEpochStartTrigger(argsNewMetaEpochStart) - - forkDetector, _ := syncFork.NewShardForkDetector( - roundHandler, - cache.NewTimeCache(time.Second), - &mock.BlockTrackerStub{}, - args.StartTime, - ) tcn.initRequestersFinder() @@ -235,7 +301,9 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { tcn.initAccountsDB() - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + genericEpochNotifier = forking.NewGenericEpochNotifier() + enableEpochsHandler, _ = enablers.NewEnableEpochsHandler(args.EnableEpochsConfig, genericEpochNotifier) + coreComponents := GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.SyncTimerField = syncer coreComponents.RoundHandlerField = roundHandler coreComponents.InternalMarshalizerField = TestMarshalizer @@ -253,11 +321,12 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { return uint32(args.ConsensusSize) }, } + coreComponents.HardforkTriggerPubKeyField = []byte("provided hardfork pub key") argsKeysHolder := keysManagement.ArgsManagedPeersHolder{ KeyGenerator: args.KeyGen, P2PKeyGenerator: 
args.P2PKeyGen, - MaxRoundsOfInactivity: 0, + MaxRoundsOfInactivity: 0, // 0 for main node, non-0 for backup node PrefsConfig: config.Preferences{}, P2PKeyConverter: p2pFactory.NewP2PKeyConverter(), } @@ -303,17 +372,26 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { cryptoComponents.SigHandler = sigHandler cryptoComponents.KeysHandlerField = keysHandler + forkDetector, _ := syncFork.NewShardForkDetector( + roundHandler, + cache.NewTimeCache(time.Second), + &mock.BlockTrackerStub{}, + args.StartTime, + enableEpochsHandler, + dataPool.Proofs(), + ) + processComponents := GetDefaultProcessComponents() processComponents.ForkDetect = forkDetector processComponents.ShardCoord = tcn.ShardCoordinator processComponents.NodesCoord = tcn.NodesCoordinator processComponents.BlockProcess = tcn.BlockProcessor processComponents.ReqFinder = tcn.RequestersFinder - processComponents.EpochTrigger = epochStartTrigger + processComponents.EpochTrigger = epochTrigger processComponents.EpochNotifier = epochStartRegistrationHandler processComponents.BlackListHdl = &testscommon.TimeCacheStub{} processComponents.BootSore = &mock.BoostrapStorerMock{} - processComponents.HeaderSigVerif = &mock.HeaderSigVerifierStub{} + processComponents.HeaderSigVerif = &consensusMocks.HeaderSigVerifierMock{} processComponents.HeaderIntegrVerif = &mock.HeaderIntegrityVerifierStub{} processComponents.ReqHandler = &testscommon.RequestHandlerStub{} processComponents.MainPeerMapper = mock.NewNetworkShardingCollectorMock() @@ -323,6 +401,9 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { processComponents.ProcessedMiniBlocksTrackerInternal = &testscommon.ProcessedMiniBlocksTrackerStub{} processComponents.SentSignaturesTrackerInternal = &testscommon.SentSignatureTrackerStub{} + tcn.initInterceptors(coreComponents, cryptoComponents, roundHandler, enableEpochsHandler, storage, epochTrigger) + processComponents.IntContainer = tcn.MainInterceptorsContainer + dataComponents := GetDefaultDataComponents() dataComponents.BlockChain = tcn.ChainHandler dataComponents.DataPool = dataPool @@ -336,7 +417,6 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { AppStatusHandlerField: &statusHandlerMock.AppStatusHandlerStub{}, } - var err error tcn.Node, err = node.NewNode( node.WithCoreComponents(coreComponents), node.WithStatusCoreComponents(statusCoreComponents), @@ -346,7 +426,6 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { node.WithStateComponents(stateComponents), node.WithNetworkComponents(networkComponents), node.WithRoundDuration(args.RoundTime), - node.WithConsensusGroupSize(args.ConsensusSize), node.WithConsensusType(args.ConsensusType), node.WithGenesisTime(time.Unix(args.StartTime, 0)), node.WithValidatorSignatureSize(signatureSize), @@ -358,6 +437,113 @@ func (tcn *TestConsensusNode) initNode(args ArgsTestConsensusNode) { } } +func (tcn *TestConsensusNode) initInterceptors( + coreComponents process.CoreComponentsHolder, + cryptoComponents process.CryptoComponentsHolder, + roundHandler consensus.RoundHandler, + enableEpochsHandler common.EnableEpochsHandler, + storage dataRetriever.StorageService, + epochStartTrigger TestEpochStartTrigger, +) { + interceptorDataVerifierArgs := interceptorsFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Second * 10, + CacheExpiry: time.Second * 10, + } + + accountsAdapter := epochStartDisabled.NewAccountsAdapter() + + blockBlackListHandler := cache.NewTimeCache(TimeSpanForBadHeaders) + + genesisBlocks := 
make(map[uint32]data.HeaderHandler) + blockTracker := processMock.NewBlockTrackerMock(tcn.ShardCoordinator, genesisBlocks) + + whiteLstHandler, _ := disabledInterceptors.NewDisabledWhiteListDataVerifier() + + cacherVerifiedCfg := storageunit.CacheConfig{Capacity: 5000, Type: storageunit.LRUCache, Shards: 1} + cacheVerified, _ := storageunit.NewCache(cacherVerifiedCfg) + whiteListerVerifiedTxs, _ := interceptors.NewWhiteListDataVerifier(cacheVerified) + + interceptorContainerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ + CoreComponents: coreComponents, + CryptoComponents: cryptoComponents, + Accounts: accountsAdapter, + ShardCoordinator: tcn.ShardCoordinator, + NodesCoordinator: tcn.NodesCoordinator, + MainMessenger: tcn.MainMessenger, + FullArchiveMessenger: tcn.FullArchiveMessenger, + Store: storage, + DataPool: tcn.DataPool, + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: &economicsmocks.EconomicsHandlerMock{}, + BlockBlackList: blockBlackListHandler, + HeaderSigVerifier: &consensusMocks.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: CreateHeaderIntegrityVerifier(), + ValidityAttester: blockTracker, + EpochStartTrigger: epochStartTrigger, + WhiteListHandler: whiteLstHandler, + WhiteListerVerifiedTxs: whiteListerVerifiedTxs, + AntifloodHandler: &mock.NilAntifloodHandler{}, + ArgumentsParser: smartContract.NewArgumentParser(), + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + SizeCheckDelta: sizeCheckDelta, + RequestHandler: &testscommon.RequestHandlerStub{}, + PeerSignatureHandler: &processMock.PeerSignatureHandlerStub{}, + SignaturesHandler: &processMock.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + MainPeerShardMapper: mock.NewNetworkShardingCollectorMock(), + FullArchivePeerShardMapper: mock.NewNetworkShardingCollectorMock(), + HardforkTrigger: &testscommon.HardforkTriggerStub{}, + NodeOperationMode: common.NormalOperation, + InterceptedDataVerifierFactory: interceptorsFactory.NewInterceptedDataVerifierFactory(interceptorDataVerifierArgs), + } + if tcn.ShardCoordinator.SelfId() == core.MetachainShardId { + interceptorContainerFactory, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(interceptorContainerFactoryArgs) + if err != nil { + fmt.Println(err.Error()) + } + + tcn.MainInterceptorsContainer, _, err = interceptorContainerFactory.Create() + if err != nil { + log.Debug("interceptor container factory Create", "error", err.Error()) + } + } else { + argsPeerMiniBlocksSyncer := shardchain.ArgPeerMiniBlockSyncer{ + MiniBlocksPool: tcn.DataPool.MiniBlocks(), + ValidatorsInfoPool: tcn.DataPool.ValidatorsInfo(), + RequestHandler: &testscommon.RequestHandlerStub{}, + } + peerMiniBlockSyncer, _ := shardchain.NewPeerMiniBlockSyncer(argsPeerMiniBlocksSyncer) + argsShardEpochStart := &shardchain.ArgsShardEpochStartTrigger{ + Marshalizer: TestMarshalizer, + Hasher: TestHasher, + HeaderValidator: &mock.HeaderValidatorStub{}, + Uint64Converter: TestUint64Converter, + DataPool: tcn.DataPool, + Storage: storage, + RequestHandler: &testscommon.RequestHandlerStub{}, + Epoch: 0, + Validity: 1, + Finality: 1, + EpochStartNotifier: notifier.NewEpochStartSubscriptionHandler(), + PeerMiniBlocksSyncer: peerMiniBlockSyncer, + RoundHandler: roundHandler, + AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, + EnableEpochsHandler: enableEpochsHandler, + } + _, _ = shardchain.NewEpochStartTrigger(argsShardEpochStart) + + interceptorContainerFactory, err := 
interceptorscontainer.NewShardInterceptorsContainerFactory(interceptorContainerFactoryArgs) + if err != nil { + fmt.Println(err.Error()) + } + + tcn.MainInterceptorsContainer, _, err = interceptorContainerFactory.Create() + if err != nil { + fmt.Println(err.Error()) + } + } +} + func (tcn *TestConsensusNode) initNodesCoordinator( consensusSize int, hasher hashing.Hasher, @@ -368,8 +554,14 @@ func (tcn *TestConsensusNode) initNodesCoordinator( cache storage.Cacher, ) { argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusSize, - MetaConsensusGroupSize: consensusSize, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusSize), + MetachainConsensusGroupSize: uint32(consensusSize), + }, nil + }, + }, Marshalizer: TestMarshalizer, Hasher: hasher, Shuffler: &shardingMocks.NodeShufflerMock{}, @@ -429,7 +621,7 @@ func (tcn *TestConsensusNode) initBlockChain(hasher hashing.Hasher) { tcn.ChainHandler.SetGenesisHeaderHash(hasher.Compute(string(hdrMarshalized))) } -func (tcn *TestConsensusNode) initBlockProcessor() { +func (tcn *TestConsensusNode) initBlockProcessor(shardId uint32) { tcn.BlockProcessor = &mock.BlockProcessorMock{ Marshalizer: TestMarshalizer, CommitBlockCalled: func(header data.HeaderHandler, body data.BodyHandler) error { @@ -453,12 +645,38 @@ func (tcn *TestConsensusNode) initBlockProcessor() { return mrsData, mrsTxs, nil }, CreateNewHeaderCalled: func(round uint64, nonce uint64) (data.HeaderHandler, error) { - return &dataBlock.Header{ - Round: round, - Nonce: nonce, - SoftwareVersion: []byte("version"), + if shardId == common.MetachainShardId { + return &dataBlock.MetaBlock{ + Round: round, + Nonce: nonce, + SoftwareVersion: []byte("version"), + ValidatorStatsRootHash: []byte("validator stats root hash"), + AccumulatedFeesInEpoch: big.NewInt(0), + DeveloperFees: big.NewInt(0), + DevFeesInEpoch: big.NewInt(0), + }, nil + } + + return &dataBlock.HeaderV2{ + Header: &dataBlock.Header{ + Round: round, + Nonce: nonce, + SoftwareVersion: []byte("version"), + }, + ScheduledDeveloperFees: big.NewInt(0), + ScheduledAccumulatedFees: big.NewInt(0), }, nil }, + DecodeBlockHeaderCalled: func(dta []byte) data.HeaderHandler { + var header data.HeaderHandler + header = &dataBlock.HeaderV2{} + if shardId == common.MetachainShardId { + header = &dataBlock.MetaBlock{} + } + + _ = TestMarshalizer.Unmarshal(header, dta) + return header + }, } } diff --git a/integrationTests/testFullNode.go b/integrationTests/testFullNode.go new file mode 100644 index 00000000000..5b325eb4180 --- /dev/null +++ b/integrationTests/testFullNode.go @@ -0,0 +1,1176 @@ +package integrationTests + +import ( + "encoding/hex" + "fmt" + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/core/versioning" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/endProcess" + "github.com/multiversx/mx-chain-core-go/hashing" + crypto "github.com/multiversx/mx-chain-crypto-go" + mclMultiSig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/multisig" + 
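The BlockProcessorMock change just above makes header handling shard aware: metachain nodes create dataBlock.MetaBlock headers, shard nodes create dataBlock.HeaderV2, and DecodeBlockHeaderCalled must unmarshal into the matching concrete type. A minimal round-trip sketch under the same assumptions, written as if it sat in the integrationTests package next to TestMarshalizer; the function name is illustrative:

package integrationTests

import (
    "github.com/multiversx/mx-chain-core-go/data"
    dataBlock "github.com/multiversx/mx-chain-core-go/data/block"

    "github.com/multiversx/mx-chain-go/common"
)

// encodeDecodeHeader marshals a freshly created header and decodes it back,
// choosing the concrete header type by shard, mirroring the mock above.
func encodeDecodeHeader(shardId uint32, round uint64, nonce uint64) (data.HeaderHandler, error) {
    var header data.HeaderHandler = &dataBlock.HeaderV2{
        Header: &dataBlock.Header{Round: round, Nonce: nonce},
    }
    if shardId == common.MetachainShardId {
        header = &dataBlock.MetaBlock{Round: round, Nonce: nonce}
    }

    buff, err := TestMarshalizer.Marshal(header)
    if err != nil {
        return nil, err
    }

    // decoding must start from the same concrete type; unmarshalling meta block
    // bytes into a HeaderV2 (or vice versa) would yield a wrong header
    var decoded data.HeaderHandler = &dataBlock.HeaderV2{}
    if shardId == common.MetachainShardId {
        decoded = &dataBlock.MetaBlock{}
    }
    err = TestMarshalizer.Unmarshal(decoded, buff)

    return decoded, err
}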
"github.com/multiversx/mx-chain-crypto-go/signing/multisig" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/round" + "github.com/multiversx/mx-chain-go/consensus/spos/sposFactory" + "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" + epochStartDisabled "github.com/multiversx/mx-chain-go/epochStart/bootstrap/disabled" + "github.com/multiversx/mx-chain-go/epochStart/metachain" + "github.com/multiversx/mx-chain-go/epochStart/notifier" + "github.com/multiversx/mx-chain-go/epochStart/shardchain" + cryptoFactory "github.com/multiversx/mx-chain-go/factory/crypto" + "github.com/multiversx/mx-chain-go/factory/peerSignatureHandler" + "github.com/multiversx/mx-chain-go/integrationTests/mock" + "github.com/multiversx/mx-chain-go/keysManagement" + "github.com/multiversx/mx-chain-go/node" + "github.com/multiversx/mx-chain-go/node/nodeDebugFactory" + "github.com/multiversx/mx-chain-go/ntp" + p2pFactory "github.com/multiversx/mx-chain-go/p2p/factory" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/process/block" + "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" + "github.com/multiversx/mx-chain-go/process/factory/interceptorscontainer" + "github.com/multiversx/mx-chain-go/process/interceptors" + disabledInterceptors "github.com/multiversx/mx-chain-go/process/interceptors/disabled" + interceptorsFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" + processMock "github.com/multiversx/mx-chain-go/process/mock" + "github.com/multiversx/mx-chain-go/process/scToProtocol" + "github.com/multiversx/mx-chain-go/process/smartContract" + processSync "github.com/multiversx/mx-chain-go/process/sync" + "github.com/multiversx/mx-chain-go/process/track" + "github.com/multiversx/mx-chain-go/sharding" + chainShardingMocks "github.com/multiversx/mx-chain-go/sharding/mock" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/state/blockInfoProviders" + "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/storage/cache" + "github.com/multiversx/mx-chain-go/storage/storageunit" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + 
"github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/factory" + testFactory "github.com/multiversx/mx-chain-go/testscommon/factory" + "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" + "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" + "github.com/multiversx/mx-chain-go/testscommon/outport" + "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" + vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" + "github.com/multiversx/mx-chain-go/vm" + "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" +) + +// CreateNodesWithTestFullNode will create a set of nodes with full consensus and processing components +func CreateNodesWithTestFullNode( + numMetaNodes int, + nodesPerShard int, + consensusSize int, + roundTime uint64, + consensusType string, + numKeysOnEachNode int, + enableEpochsConfig config.EnableEpochs, + withSync bool, +) map[uint32][]*TestFullNode { + + nodes := make(map[uint32][]*TestFullNode, nodesPerShard) + cp := CreateCryptoParams(nodesPerShard, numMetaNodes, maxShards, numKeysOnEachNode) + keysMap := PubKeysMapFromNodesKeysMap(cp.NodesKeys) + validatorsMap := GenValidatorsFromPubKeys(keysMap, maxShards) + eligibleMap, _ := nodesCoordinator.NodesInfoToValidators(validatorsMap) + waitingMap := make(map[uint32][]nodesCoordinator.Validator) + connectableNodes := make(map[uint32][]Connectable, 0) + + startTime := time.Now().Unix() + testHasher := createHasher(consensusType) + + for shardID := range cp.NodesKeys { + for _, keysPair := range cp.NodesKeys[shardID] { + multiSigner, _ := multisig.NewBLSMultisig(&mclMultiSig.BlsMultiSigner{Hasher: testHasher}, cp.KeyGen) + multiSignerMock := createCustomMultiSignerMock(multiSigner) + + args := ArgsTestFullNode{ + ArgTestProcessorNode: &ArgTestProcessorNode{ + MaxShards: 2, + NodeShardId: 0, + TxSignPrivKeyShardId: 0, + WithSync: withSync, + EpochsConfig: &enableEpochsConfig, + NodeKeys: keysPair, + }, + ShardID: shardID, + ConsensusSize: consensusSize, + RoundTime: roundTime, + ConsensusType: consensusType, + EligibleMap: eligibleMap, + WaitingMap: waitingMap, + KeyGen: cp.KeyGen, + P2PKeyGen: cp.P2PKeyGen, + MultiSigner: multiSignerMock, + StartTime: startTime, + } + + tfn := NewTestFullNode(args) + nodes[shardID] = append(nodes[shardID], tfn) + connectableNodes[shardID] = append(connectableNodes[shardID], tfn) + } + } + + for shardID := range nodes { + ConnectNodes(connectableNodes[shardID]) + } + + return nodes +} + +// ArgsTestFullNode defines arguments for test full node +type ArgsTestFullNode struct { + *ArgTestProcessorNode + + ShardID uint32 + ConsensusSize int + RoundTime uint64 + ConsensusType string + EligibleMap map[uint32][]nodesCoordinator.Validator + WaitingMap map[uint32][]nodesCoordinator.Validator + KeyGen crypto.KeyGenerator + P2PKeyGen crypto.KeyGenerator + MultiSigner *cryptoMocks.MultisignerMock + StartTime int64 +} + +// TestFullNode defines the structure for testing node with full 
processing and consensus components +type TestFullNode struct { + *TestProcessorNode + + ShardCoordinator sharding.Coordinator + MultiSigner *cryptoMocks.MultisignerMock + GenesisTimeField time.Time +} + +// NewTestFullNode will create a new instance of full testing node +func NewTestFullNode(args ArgsTestFullNode) *TestFullNode { + tpn := newBaseTestProcessorNode(*args.ArgTestProcessorNode) + + shardCoordinator, _ := sharding.NewMultiShardCoordinator(maxShards, args.ShardID) + + tfn := &TestFullNode{ + TestProcessorNode: tpn, + ShardCoordinator: shardCoordinator, + MultiSigner: args.MultiSigner, + } + + tfn.initTestNodeWithArgs(*args.ArgTestProcessorNode, args) + + return tfn +} + +func (tfn *TestFullNode) initNodesCoordinator( + consensusSize int, + hasher hashing.Hasher, + epochStartRegistrationHandler notifier.EpochStartNotifier, + eligibleMap map[uint32][]nodesCoordinator.Validator, + waitingMap map[uint32][]nodesCoordinator.Validator, + pkBytes []byte, + cache storage.Cacher, +) { + argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusSize), + MetachainConsensusGroupSize: uint32(consensusSize), + }, nil + }, + }, + Marshalizer: TestMarshalizer, + Hasher: hasher, + Shuffler: &shardingMocks.NodeShufflerMock{}, + EpochStartNotifier: epochStartRegistrationHandler, + BootStorer: CreateMemUnit(), + NbShards: maxShards, + EligibleNodes: eligibleMap, + WaitingNodes: waitingMap, + SelfPublicKey: pkBytes, + ConsensusGroupCache: cache, + ShuffledOutHandler: &chainShardingMocks.ShuffledOutHandlerStub{}, + ChanStopNode: endProcess.GetDummyEndProcessChannel(), + NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, + IsFullArchive: false, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, + ShardIDAsObserver: tfn.ShardCoordinator.SelfId(), + GenesisNodesSetupHandler: &genesisMocks.NodesSetupStub{}, + NodesCoordinatorRegistryFactory: &shardingMocks.NodesCoordinatorRegistryFactoryMock{}, + } + + tfn.NodesCoordinator, _ = nodesCoordinator.NewIndexHashedNodesCoordinator(argumentsNodesCoordinator) +} + +func (tpn *TestFullNode) initTestNodeWithArgs(args ArgTestProcessorNode, fullArgs ArgsTestFullNode) { + tpn.AppStatusHandler = args.AppStatusHandler + if check.IfNil(args.AppStatusHandler) { + tpn.AppStatusHandler = TestAppStatusHandler + } + + tpn.MainMessenger = CreateMessengerWithNoDiscovery() + + tpn.StatusMetrics = args.StatusMetrics + if check.IfNil(args.StatusMetrics) { + args.StatusMetrics = &testscommon.StatusMetricsStub{} + } + + tpn.initChainHandler() + tpn.initHeaderValidator() + tpn.initRoundHandler() + + syncer := ntp.NewSyncTime(ntp.NewNTPGoogleConfig(), nil) + syncer.StartSyncingTime() + tpn.GenesisTimeField = time.Unix(fullArgs.StartTime, 0) + + roundHandler, _ := round.NewRound( + tpn.GenesisTimeField, + syncer.CurrentTime(), + time.Millisecond*time.Duration(fullArgs.RoundTime), + syncer, + 0) + + tpn.NetworkShardingCollector = mock.NewNetworkShardingCollectorMock() + if check.IfNil(tpn.EpochNotifier) { + tpn.EpochStartNotifier = notifier.NewEpochStartSubscriptionHandler() + } + tpn.initStorage() + if check.IfNil(args.TrieStore) { + tpn.initAccountDBsWithPruningStorer() + } else { + tpn.initAccountDBs(args.TrieStore) + } + + 
economicsConfig := args.EconomicsConfig + if economicsConfig == nil { + economicsConfig = createDefaultEconomicsConfig() + } + + tpn.initEconomicsData(economicsConfig) + tpn.initRatingsData() + tpn.initRequestedItemsHandler() + tpn.initResolvers() + tpn.initRequesters() + tpn.initValidatorStatistics() + tpn.initGenesisBlocks(args) + tpn.initBlockTracker(roundHandler) + + gasMap := wasmConfig.MakeGasMapForTests() + defaults.FillGasMapInternal(gasMap, 1) + if args.GasScheduleMap != nil { + gasMap = args.GasScheduleMap + } + vmConfig := getDefaultVMConfig() + if args.VMConfig != nil { + vmConfig = args.VMConfig + } + tpn.initInnerProcessors(gasMap, vmConfig) + + if check.IfNil(args.TrieStore) { + var apiBlockchain data.ChainHandler + if tpn.ShardCoordinator.SelfId() == core.MetachainShardId { + apiBlockchain, _ = blockchain.NewMetaChain(statusHandlerMock.NewAppStatusHandlerMock()) + } else { + apiBlockchain, _ = blockchain.NewBlockChain(statusHandlerMock.NewAppStatusHandlerMock()) + } + argsNewScQueryService := smartContract.ArgsNewSCQueryService{ + VmContainer: tpn.VMContainer, + EconomicsFee: tpn.EconomicsData, + BlockChainHook: tpn.BlockchainHook, + MainBlockChain: tpn.BlockChain, + APIBlockChain: apiBlockchain, + WasmVMChangeLocker: tpn.WasmVMChangeLocker, + Bootstrapper: tpn.Bootstrapper, + AllowExternalQueriesChan: common.GetClosedUnbufferedChannel(), + HistoryRepository: tpn.HistoryRepository, + ShardCoordinator: tpn.ShardCoordinator, + StorageService: tpn.Storage, + Marshaller: TestMarshaller, + Hasher: TestHasher, + Uint64ByteSliceConverter: TestUint64Converter, + } + tpn.SCQueryService, _ = smartContract.NewSCQueryService(argsNewScQueryService) + } else { + tpn.createFullSCQueryService(gasMap, vmConfig) + } + + testHasher := createHasher(fullArgs.ConsensusType) + epochStartRegistrationHandler := notifier.NewEpochStartSubscriptionHandler() + pkBytes, _ := tpn.NodeKeys.MainKey.Pk.ToByteArray() + consensusCache, _ := cache.NewLRUCache(10000) + + tpn.initNodesCoordinator( + fullArgs.ConsensusSize, + testHasher, + epochStartRegistrationHandler, + fullArgs.EligibleMap, + fullArgs.WaitingMap, + pkBytes, + consensusCache, + ) + + tpn.BroadcastMessenger, _ = sposFactory.GetBroadcastMessenger( + TestMarshalizer, + TestHasher, + tpn.MainMessenger, + tpn.ShardCoordinator, + tpn.OwnAccount.PeerSigHandler, + tpn.DataPool.Headers(), + tpn.MainInterceptorsContainer, + &testscommon.AlarmSchedulerStub{}, + testscommon.NewKeysHandlerSingleSignerMock( + tpn.NodeKeys.MainKey.Sk, + tpn.MainMessenger.ID(), + ), + ) + + if args.WithSync { + tpn.initBootstrapper() + } + tpn.setGenesisBlock() + tpn.initNode(fullArgs, syncer, roundHandler) + tpn.addHandlersForCounters() + tpn.addGenesisBlocksIntoStorage() + + if args.GenesisFile != "" { + tpn.createHeartbeatWithHardforkTrigger() + } +} + +func (tpn *TestFullNode) setGenesisBlock() { + genesisBlock := tpn.GenesisBlocks[tpn.ShardCoordinator.SelfId()] + _ = tpn.BlockChain.SetGenesisHeader(genesisBlock) + hash, _ := core.CalculateHash(TestMarshalizer, TestHasher, genesisBlock) + tpn.BlockChain.SetGenesisHeaderHash(hash) + log.Info("set genesis", + "shard ID", tpn.ShardCoordinator.SelfId(), + "hash", hex.EncodeToString(hash), + ) +} + +func (tpn *TestFullNode) initChainHandler() { + if tpn.ShardCoordinator.SelfId() == core.MetachainShardId { + tpn.BlockChain = CreateMetaChain() + } else { + tpn.BlockChain = CreateShardChain() + } +} + +func (tpn *TestFullNode) initNode( + args ArgsTestFullNode, + syncer ntp.SyncTimer, + roundHandler consensus.RoundHandler, +) { + 
var err error + + statusCoreComponents := &testFactory.StatusCoreComponentsStub{ + StatusMetricsField: tpn.StatusMetrics, + AppStatusHandlerField: tpn.AppStatusHandler, + } + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + + epochTrigger := tpn.createEpochStartTrigger(args.StartTime) + tpn.EpochStartTrigger = epochTrigger + + strPk := "" + if !check.IfNil(args.HardforkPk) { + buff, err := args.HardforkPk.ToByteArray() + log.LogIfError(err) + + strPk = hex.EncodeToString(buff) + } + _ = tpn.createHardforkTrigger(strPk) + + coreComponents := GetDefaultCoreComponents(tpn.EnableEpochsHandler, tpn.EpochNotifier) + coreComponents.SyncTimerField = syncer + coreComponents.RoundHandlerField = roundHandler + + coreComponents.InternalMarshalizerField = TestMarshalizer + coreComponents.VmMarshalizerField = TestVmMarshalizer + coreComponents.TxMarshalizerField = TestTxSignMarshalizer + coreComponents.HasherField = TestHasher + coreComponents.AddressPubKeyConverterField = TestAddressPubkeyConverter + coreComponents.ValidatorPubKeyConverterField = TestValidatorPubkeyConverter + coreComponents.ChainIdCalled = func() string { + return string(tpn.ChainID) + } + + coreComponents.GenesisTimeField = tpn.GenesisTimeField + coreComponents.GenesisNodesSetupField = &genesisMocks.NodesSetupStub{ + GetShardConsensusGroupSizeCalled: func() uint32 { + return uint32(args.ConsensusSize) + }, + GetMetaConsensusGroupSizeCalled: func() uint32 { + return uint32(args.ConsensusSize) + }, + } + coreComponents.MinTransactionVersionCalled = func() uint32 { + return tpn.MinTransactionVersion + } + coreComponents.TxVersionCheckField = versioning.NewTxVersionChecker(tpn.MinTransactionVersion) + hardforkPubKeyBytes, _ := coreComponents.ValidatorPubKeyConverterField.Decode(hardforkPubKey) + coreComponents.HardforkTriggerPubKeyField = hardforkPubKeyBytes + coreComponents.Uint64ByteSliceConverterField = TestUint64Converter + coreComponents.EconomicsDataField = tpn.EconomicsData + coreComponents.APIEconomicsHandler = tpn.EconomicsData + coreComponents.EnableEpochsHandlerField = tpn.EnableEpochsHandler + coreComponents.EpochNotifierField = tpn.EpochNotifier + coreComponents.RoundNotifierField = tpn.RoundNotifier + coreComponents.WasmVMChangeLockerInternal = tpn.WasmVMChangeLocker + coreComponents.EconomicsDataField = tpn.EconomicsData + + dataComponents := GetDefaultDataComponents() + dataComponents.BlockChain = tpn.BlockChain + dataComponents.DataPool = tpn.DataPool + dataComponents.Store = tpn.Storage + + bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator, tpn.EnableEpochsHandler) + + tpn.BlockBlackListHandler = cache.NewTimeCache(TimeSpanForBadHeaders) + tpn.ForkDetector = tpn.createForkDetector(args.StartTime, roundHandler) + + argsKeysHolder := keysManagement.ArgsManagedPeersHolder{ + KeyGenerator: args.KeyGen, + P2PKeyGenerator: args.P2PKeyGen, + MaxRoundsOfInactivity: 0, // 0 for main node, non-0 for backup node + PrefsConfig: config.Preferences{}, + P2PKeyConverter: p2pFactory.NewP2PKeyConverter(), + } + keysHolder, _ := keysManagement.NewManagedPeersHolder(argsKeysHolder) + + // adding provided handled keys + for _, key := range args.NodeKeys.HandledKeys { + skBytes, _ := key.Sk.ToByteArray() + _ = keysHolder.AddManagedPeer(skBytes) + } + + multiSigContainer := 
cryptoMocks.NewMultiSignerContainerMock(args.MultiSigner) + pubKey := tpn.NodeKeys.MainKey.Sk.GeneratePublic() + pubKeyBytes, _ := pubKey.ToByteArray() + pubKeyString := coreComponents.ValidatorPubKeyConverterField.SilentEncode(pubKeyBytes, log) + argsKeysHandler := keysManagement.ArgsKeysHandler{ + ManagedPeersHolder: keysHolder, + PrivateKey: tpn.NodeKeys.MainKey.Sk, + Pid: tpn.MainMessenger.ID(), + } + keysHandler, _ := keysManagement.NewKeysHandler(argsKeysHandler) + + signingHandlerArgs := cryptoFactory.ArgsSigningHandler{ + PubKeys: []string{pubKeyString}, + MultiSignerContainer: multiSigContainer, + KeyGenerator: args.KeyGen, + KeysHandler: keysHandler, + SingleSigner: TestSingleBlsSigner, + } + sigHandler, _ := cryptoFactory.NewSigningHandler(signingHandlerArgs) + + cryptoComponents := GetDefaultCryptoComponents() + cryptoComponents.PrivKey = tpn.NodeKeys.MainKey.Sk + cryptoComponents.PubKey = tpn.NodeKeys.MainKey.Pk + cryptoComponents.TxSig = tpn.OwnAccount.SingleSigner + cryptoComponents.BlockSig = tpn.OwnAccount.SingleSigner + cryptoComponents.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(tpn.MultiSigner) + cryptoComponents.BlKeyGen = tpn.OwnAccount.KeygenTxSign + cryptoComponents.TxKeyGen = TestKeyGenForAccounts + + peerSigCache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000}) + peerSigHandler, _ := peerSignatureHandler.NewPeerSignatureHandler(peerSigCache, TestSingleBlsSigner, args.KeyGen) + cryptoComponents.PeerSignHandler = peerSigHandler + cryptoComponents.SigHandler = sigHandler + cryptoComponents.KeysHandlerField = keysHandler + + tpn.initInterceptors(coreComponents, cryptoComponents, roundHandler, tpn.EnableEpochsHandler, tpn.Storage, epochTrigger) + + if args.WithSync { + tpn.initBlockProcessorWithSync(coreComponents, dataComponents, roundHandler) + } else { + tpn.initBlockProcessor(coreComponents, dataComponents, args, roundHandler) + } + + processComponents := GetDefaultProcessComponents() + processComponents.ForkDetect = tpn.ForkDetector + processComponents.BlockProcess = tpn.BlockProcessor + processComponents.ReqFinder = tpn.RequestersFinder + processComponents.HeaderIntegrVerif = tpn.HeaderIntegrityVerifier + processComponents.HeaderSigVerif = tpn.HeaderSigVerifier + processComponents.BlackListHdl = tpn.BlockBlackListHandler + processComponents.NodesCoord = tpn.NodesCoordinator + processComponents.ShardCoord = tpn.ShardCoordinator + processComponents.IntContainer = tpn.MainInterceptorsContainer + processComponents.FullArchiveIntContainer = tpn.FullArchiveInterceptorsContainer + processComponents.HistoryRepositoryInternal = tpn.HistoryRepository + processComponents.WhiteListHandlerInternal = tpn.WhiteListHandler + processComponents.WhiteListerVerifiedTxsInternal = tpn.WhiteListerVerifiedTxs + processComponents.TxsSenderHandlerField = createTxsSender(tpn.ShardCoordinator, tpn.MainMessenger) + processComponents.HardforkTriggerField = tpn.HardforkTrigger + processComponents.ScheduledTxsExecutionHandlerInternal = &testscommon.ScheduledTxsExecutionStub{} + processComponents.ProcessedMiniBlocksTrackerInternal = &testscommon.ProcessedMiniBlocksTrackerStub{} + processComponents.SentSignaturesTrackerInternal = &testscommon.SentSignatureTrackerStub{} + + processComponents.RoundHandlerField = roundHandler + processComponents.EpochNotifier = tpn.EpochStartNotifier + + stateComponents := GetDefaultStateComponents() + stateComponents.Accounts = tpn.AccntState + stateComponents.AccountsAPI = tpn.AccntState + + 
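// wrap the accounts adapter with final, current and historical block info providers so the accounts repository exposes per-block state views to the node API +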
finalProvider, _ := blockInfoProviders.NewFinalBlockInfo(dataComponents.BlockChain) + finalAccountsApi, _ := state.NewAccountsDBApi(tpn.AccntState, finalProvider) + + currentProvider, _ := blockInfoProviders.NewCurrentBlockInfo(dataComponents.BlockChain) + currentAccountsApi, _ := state.NewAccountsDBApi(tpn.AccntState, currentProvider) + + historicalAccountsApi, _ := state.NewAccountsDBApiWithHistory(tpn.AccntState) + + argsAccountsRepo := state.ArgsAccountsRepository{ + FinalStateAccountsWrapper: finalAccountsApi, + CurrentStateAccountsWrapper: currentAccountsApi, + HistoricalStateAccountsWrapper: historicalAccountsApi, + } + stateComponents.AccountsRepo, _ = state.NewAccountsRepository(argsAccountsRepo) + + networkComponents := GetDefaultNetworkComponents() + networkComponents.Messenger = tpn.MainMessenger + networkComponents.FullArchiveNetworkMessengerField = tpn.FullArchiveMessenger + networkComponents.PeersRatingHandlerField = tpn.PeersRatingHandler + networkComponents.PeersRatingMonitorField = tpn.PeersRatingMonitor + networkComponents.InputAntiFlood = &mock.NilAntifloodHandler{} + networkComponents.PeerHonesty = &mock.PeerHonestyHandlerStub{} + + tpn.Node, err = node.NewNode( + node.WithAddressSignatureSize(64), + node.WithValidatorSignatureSize(48), + node.WithBootstrapComponents(bootstrapComponents), + node.WithCoreComponents(coreComponents), + node.WithStatusCoreComponents(statusCoreComponents), + node.WithDataComponents(dataComponents), + node.WithProcessComponents(processComponents), + node.WithCryptoComponents(cryptoComponents), + node.WithNetworkComponents(networkComponents), + node.WithStateComponents(stateComponents), + node.WithPeerDenialEvaluator(&mock.PeerDenialEvaluatorStub{}), + node.WithStatusCoreComponents(statusCoreComponents), + node.WithRoundDuration(args.RoundTime), + node.WithPublicKeySize(publicKeySize), + ) + log.LogIfError(err) + + err = nodeDebugFactory.CreateInterceptedDebugHandler( + tpn.Node, + tpn.MainInterceptorsContainer, + tpn.ResolversContainer, + tpn.RequestersFinder, + config.InterceptorResolverDebugConfig{ + Enabled: true, + CacheSize: 1000, + EnablePrint: true, + IntervalAutoPrintInSeconds: 1, + NumRequestsThreshold: 1, + NumResolveFailureThreshold: 1, + DebugLineExpiration: 1000, + }, + ) + log.LogIfError(err) +} + +func (tfn *TestFullNode) createForkDetector( + startTime int64, + roundHandler consensus.RoundHandler, +) process.ForkDetector { + var err error + var forkDetector process.ForkDetector + + if tfn.ShardCoordinator.SelfId() != core.MetachainShardId { + forkDetector, err = processSync.NewShardForkDetector( + roundHandler, + tfn.BlockBlackListHandler, + tfn.BlockTracker, + tfn.GenesisTimeField.Unix(), + tfn.EnableEpochsHandler, + tfn.DataPool.Proofs()) + } else { + forkDetector, err = processSync.NewMetaForkDetector( + roundHandler, + tfn.BlockBlackListHandler, + tfn.BlockTracker, + tfn.GenesisTimeField.Unix(), + tfn.EnableEpochsHandler, + tfn.DataPool.Proofs()) + } + if err != nil { + log.Error("error creating fork detector", "error", err) + return nil + } + + return forkDetector +} + +func (tfn *TestFullNode) createEpochStartTrigger(startTime int64) TestEpochStartTrigger { + var epochTrigger TestEpochStartTrigger + if tfn.ShardCoordinator.SelfId() == core.MetachainShardId { + argsNewMetaEpochStart := &metachain.ArgsNewMetaEpochStartTrigger{ + GenesisTime: tfn.GenesisTimeField, + EpochStartNotifier: notifier.NewEpochStartSubscriptionHandler(), + Settings: &config.EpochStartConfig{ + MinRoundsBetweenEpochs: 1, + RoundsPerEpoch: 1000, + 
}, + Epoch: 0, + Storage: createTestStore(), + Marshalizer: TestMarshalizer, + Hasher: TestHasher, + AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, + DataPool: tfn.DataPool, + } + epochStartTrigger, err := metachain.NewEpochStartTrigger(argsNewMetaEpochStart) + if err != nil { + fmt.Println(err.Error()) + } + epochTrigger = &metachain.TestTrigger{} + epochTrigger.SetTrigger(epochStartTrigger) + } else { + argsPeerMiniBlocksSyncer := shardchain.ArgPeerMiniBlockSyncer{ + MiniBlocksPool: tfn.DataPool.MiniBlocks(), + ValidatorsInfoPool: tfn.DataPool.ValidatorsInfo(), + RequestHandler: &testscommon.RequestHandlerStub{}, + } + peerMiniBlockSyncer, _ := shardchain.NewPeerMiniBlockSyncer(argsPeerMiniBlocksSyncer) + + argsShardEpochStart := &shardchain.ArgsShardEpochStartTrigger{ + Marshalizer: TestMarshalizer, + Hasher: TestHasher, + HeaderValidator: &mock.HeaderValidatorStub{}, + Uint64Converter: TestUint64Converter, + DataPool: tfn.DataPool, + Storage: tfn.Storage, + RequestHandler: &testscommon.RequestHandlerStub{}, + Epoch: 0, + Validity: 1, + Finality: 1, + EpochStartNotifier: notifier.NewEpochStartSubscriptionHandler(), + PeerMiniBlocksSyncer: peerMiniBlockSyncer, + RoundHandler: tfn.RoundHandler, + AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, + EnableEpochsHandler: tfn.EnableEpochsHandler, + } + epochStartTrigger, err := shardchain.NewEpochStartTrigger(argsShardEpochStart) + if err != nil { + fmt.Println(err.Error()) + } + epochTrigger = &shardchain.TestTrigger{} + epochTrigger.SetTrigger(epochStartTrigger) + } + + return epochTrigger +} + +func (tcn *TestFullNode) initInterceptors( + coreComponents process.CoreComponentsHolder, + cryptoComponents process.CryptoComponentsHolder, + roundHandler consensus.RoundHandler, + enableEpochsHandler common.EnableEpochsHandler, + storage dataRetriever.StorageService, + epochStartTrigger TestEpochStartTrigger, +) { + interceptorDataVerifierArgs := interceptorsFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Second * 10, + CacheExpiry: time.Second * 10, + } + + accountsAdapter := epochStartDisabled.NewAccountsAdapter() + + blockBlackListHandler := cache.NewTimeCache(TimeSpanForBadHeaders) + + genesisBlocks := make(map[uint32]data.HeaderHandler) + blockTracker := processMock.NewBlockTrackerMock(tcn.ShardCoordinator, genesisBlocks) + + whiteLstHandler, _ := disabledInterceptors.NewDisabledWhiteListDataVerifier() + + cacherVerifiedCfg := storageunit.CacheConfig{Capacity: 5000, Type: storageunit.LRUCache, Shards: 1} + cacheVerified, _ := storageunit.NewCache(cacherVerifiedCfg) + whiteListerVerifiedTxs, _ := interceptors.NewWhiteListDataVerifier(cacheVerified) + + interceptorContainerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ + CoreComponents: coreComponents, + CryptoComponents: cryptoComponents, + Accounts: accountsAdapter, + ShardCoordinator: tcn.ShardCoordinator, + NodesCoordinator: tcn.NodesCoordinator, + MainMessenger: tcn.MainMessenger, + FullArchiveMessenger: tcn.FullArchiveMessenger, + Store: storage, + DataPool: tcn.DataPool, + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: &economicsmocks.EconomicsHandlerMock{}, + BlockBlackList: blockBlackListHandler, + HeaderSigVerifier: &consensusMocks.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: CreateHeaderIntegrityVerifier(), + ValidityAttester: blockTracker, + EpochStartTrigger: epochStartTrigger, + WhiteListHandler: whiteLstHandler, + WhiteListerVerifiedTxs: whiteListerVerifiedTxs, + AntifloodHandler: 
&mock.NilAntifloodHandler{}, + ArgumentsParser: smartContract.NewArgumentParser(), + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + SizeCheckDelta: sizeCheckDelta, + RequestHandler: &testscommon.RequestHandlerStub{}, + PeerSignatureHandler: &processMock.PeerSignatureHandlerStub{}, + SignaturesHandler: &processMock.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + MainPeerShardMapper: mock.NewNetworkShardingCollectorMock(), + FullArchivePeerShardMapper: mock.NewNetworkShardingCollectorMock(), + HardforkTrigger: &testscommon.HardforkTriggerStub{}, + NodeOperationMode: common.NormalOperation, + InterceptedDataVerifierFactory: interceptorsFactory.NewInterceptedDataVerifierFactory(interceptorDataVerifierArgs), + } + if tcn.ShardCoordinator.SelfId() == core.MetachainShardId { + interceptorContainerFactory, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(interceptorContainerFactoryArgs) + if err != nil { + fmt.Println(err.Error()) + } + + tcn.MainInterceptorsContainer, _, err = interceptorContainerFactory.Create() + if err != nil { + log.Debug("interceptor container factory Create", "error", err.Error()) + } + } else { + argsPeerMiniBlocksSyncer := shardchain.ArgPeerMiniBlockSyncer{ + MiniBlocksPool: tcn.DataPool.MiniBlocks(), + ValidatorsInfoPool: tcn.DataPool.ValidatorsInfo(), + RequestHandler: &testscommon.RequestHandlerStub{}, + } + peerMiniBlockSyncer, _ := shardchain.NewPeerMiniBlockSyncer(argsPeerMiniBlocksSyncer) + argsShardEpochStart := &shardchain.ArgsShardEpochStartTrigger{ + Marshalizer: TestMarshalizer, + Hasher: TestHasher, + HeaderValidator: &mock.HeaderValidatorStub{}, + Uint64Converter: TestUint64Converter, + DataPool: tcn.DataPool, + Storage: storage, + RequestHandler: &testscommon.RequestHandlerStub{}, + Epoch: 0, + Validity: 1, + Finality: 1, + EpochStartNotifier: notifier.NewEpochStartSubscriptionHandler(), + PeerMiniBlocksSyncer: peerMiniBlockSyncer, + RoundHandler: roundHandler, + AppStatusHandler: &statusHandlerMock.AppStatusHandlerStub{}, + EnableEpochsHandler: enableEpochsHandler, + } + _, _ = shardchain.NewEpochStartTrigger(argsShardEpochStart) + + interceptorContainerFactory, err := interceptorscontainer.NewShardInterceptorsContainerFactory(interceptorContainerFactoryArgs) + if err != nil { + fmt.Println(err.Error()) + } + + tcn.MainInterceptorsContainer, _, err = interceptorContainerFactory.Create() + if err != nil { + fmt.Println(err.Error()) + } + } +} + +func (tpn *TestFullNode) initBlockProcessor( + coreComponents *mock.CoreComponentsStub, + dataComponents *mock.DataComponentsStub, + args ArgsTestFullNode, + roundHandler consensus.RoundHandler, +) { + var err error + + accountsDb := make(map[state.AccountsDbIdentifier]state.AccountsAdapter) + accountsDb[state.UserAccountsState] = tpn.AccntState + accountsDb[state.PeerAccountsState] = tpn.PeerState + + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + + bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator, tpn.EnableEpochsHandler) + bootstrapComponents.HdrIntegrityVerifier = tpn.HeaderIntegrityVerifier + + statusComponents := GetDefaultStatusComponents() + + statusCoreComponents := &testFactory.StatusCoreComponentsStub{ + AppStatusHandlerField: &statusHandlerMock.AppStatusHandlerStub{}, + } + + argumentsBase := block.ArgBaseProcessor{ + CoreComponents: coreComponents, 
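+ // argumentsBase is shared by both the metachain and the shard block processor branches below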
+ DataComponents: dataComponents, + BootstrapComponents: bootstrapComponents, + StatusComponents: statusComponents, + StatusCoreComponents: statusCoreComponents, + Config: config.Config{}, + AccountsDB: accountsDb, + ForkDetector: tpn.ForkDetector, + NodesCoordinator: tpn.NodesCoordinator, + FeeHandler: tpn.FeeAccumulator, + RequestHandler: tpn.RequestHandler, + BlockChainHook: tpn.BlockchainHook, + HeaderValidator: tpn.HeaderValidator, + BootStorer: &mock.BoostrapStorerMock{ + PutCalled: func(round int64, bootData bootstrapStorage.BootstrapData) error { + return nil + }, + }, + BlockTracker: tpn.BlockTracker, + BlockSizeThrottler: TestBlockSizeThrottler, + HistoryRepository: tpn.HistoryRepository, + GasHandler: tpn.GasHandler, + ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, + ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, + ReceiptsRepository: &testscommon.ReceiptsRepositoryStub{}, + OutportDataProvider: &outport.OutportDataProviderStub{}, + BlockProcessingCutoffHandler: &testscommon.BlockProcessingCutoffStub{}, + ManagedPeersHolder: &testscommon.ManagedPeersHolderStub{}, + SentSignaturesTracker: &testscommon.SentSignatureTrackerStub{}, + } + + if check.IfNil(tpn.EpochStartNotifier) { + tpn.EpochStartNotifier = notifier.NewEpochStartSubscriptionHandler() + } + + if tpn.ShardCoordinator.SelfId() == core.MetachainShardId { + argumentsBase.EpochStartTrigger = tpn.EpochStartTrigger + argumentsBase.TxCoordinator = tpn.TxCoordinator + + argsStakingToPeer := scToProtocol.ArgStakingToPeer{ + PubkeyConv: TestValidatorPubkeyConverter, + Hasher: TestHasher, + Marshalizer: TestMarshalizer, + PeerState: tpn.PeerState, + BaseState: tpn.AccntState, + ArgParser: tpn.ArgsParser, + CurrTxs: tpn.DataPool.CurrentBlockTxs(), + RatingsData: tpn.RatingsData, + EnableEpochsHandler: tpn.EnableEpochsHandler, + } + scToProtocolInstance, _ := scToProtocol.NewStakingToPeer(argsStakingToPeer) + + argsEpochStartData := metachain.ArgsNewEpochStartData{ + Marshalizer: TestMarshalizer, + Hasher: TestHasher, + Store: tpn.Storage, + DataPool: tpn.DataPool, + BlockTracker: tpn.BlockTracker, + ShardCoordinator: tpn.ShardCoordinator, + EpochStartTrigger: tpn.EpochStartTrigger, + RequestHandler: tpn.RequestHandler, + EnableEpochsHandler: tpn.EnableEpochsHandler, + } + epochStartDataCreator, _ := metachain.NewEpochStartData(argsEpochStartData) + + economicsDataProvider := metachain.NewEpochEconomicsStatistics() + argsEpochEconomics := metachain.ArgsNewEpochEconomics{ + Marshalizer: TestMarshalizer, + Hasher: TestHasher, + Store: tpn.Storage, + ShardCoordinator: tpn.ShardCoordinator, + RewardsHandler: tpn.EconomicsData, + RoundTime: roundHandler, + GenesisTotalSupply: tpn.EconomicsData.GenesisTotalSupply(), + EconomicsDataNotified: economicsDataProvider, + StakingV2EnableEpoch: tpn.EnableEpochs.StakingV2EnableEpoch, + } + epochEconomics, _ := metachain.NewEndOfEpochEconomicsDataCreator(argsEpochEconomics) + + systemVM, _ := mock.NewOneSCExecutorMockVM(tpn.BlockchainHook, TestHasher) + + argsStakingDataProvider := metachain.StakingDataProviderArgs{ + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), + SystemVM: systemVM, + MinNodePrice: "1000", + } + stakingDataProvider, errRsp := metachain.NewStakingDataProvider(argsStakingDataProvider) + if errRsp != nil { + log.Error("initBlockProcessor NewRewardsStakingProvider", "error", errRsp) + } + + rewardsStorage, _ := tpn.Storage.GetStorer(dataRetriever.RewardTransactionUnit) + miniBlockStorage, _ := 
tpn.Storage.GetStorer(dataRetriever.MiniBlockUnit) + argsEpochRewards := metachain.RewardsCreatorProxyArgs{ + BaseRewardsCreatorArgs: metachain.BaseRewardsCreatorArgs{ + ShardCoordinator: tpn.ShardCoordinator, + PubkeyConverter: TestAddressPubkeyConverter, + RewardsStorage: rewardsStorage, + MiniBlockStorage: miniBlockStorage, + Hasher: TestHasher, + Marshalizer: TestMarshalizer, + DataPool: tpn.DataPool, + NodesConfigProvider: tpn.NodesCoordinator, + UserAccountsDB: tpn.AccntState, + EnableEpochsHandler: tpn.EnableEpochsHandler, + ExecutionOrderHandler: tpn.TxExecutionOrderHandler, + RewardsHandler: tpn.EconomicsData, + }, + StakingDataProvider: stakingDataProvider, + EconomicsDataProvider: economicsDataProvider, + } + epochStartRewards, err := metachain.NewRewardsCreatorProxy(argsEpochRewards) + if err != nil { + log.Error("error creating rewards proxy", "error", err) + } + + validatorInfoStorage, _ := tpn.Storage.GetStorer(dataRetriever.UnsignedTransactionUnit) + argsEpochValidatorInfo := metachain.ArgsNewValidatorInfoCreator{ + ShardCoordinator: tpn.ShardCoordinator, + ValidatorInfoStorage: validatorInfoStorage, + MiniBlockStorage: miniBlockStorage, + Hasher: TestHasher, + Marshalizer: TestMarshalizer, + DataPool: tpn.DataPool, + EnableEpochsHandler: tpn.EnableEpochsHandler, + } + epochStartValidatorInfo, _ := metachain.NewValidatorInfoCreator(argsEpochValidatorInfo) + + maxNodesChangeConfigProvider, _ := notifier.NewNodesConfigProvider( + tpn.EpochNotifier, + nil, + ) + auctionCfg := config.SoftAuctionConfig{ + TopUpStep: "10", + MinTopUp: "1", + MaxTopUp: "32000000", + MaxNumberOfIterations: 100000, + } + ald, _ := metachain.NewAuctionListDisplayer(metachain.ArgsAuctionListDisplayer{ + TableDisplayHandler: metachain.NewTableDisplayer(), + ValidatorPubKeyConverter: &testscommon.PubkeyConverterMock{}, + AddressPubKeyConverter: &testscommon.PubkeyConverterMock{}, + AuctionConfig: auctionCfg, + }) + + argsAuctionListSelector := metachain.AuctionListSelectorArgs{ + ShardCoordinator: tpn.ShardCoordinator, + StakingDataProvider: stakingDataProvider, + MaxNodesChangeConfigProvider: maxNodesChangeConfigProvider, + AuctionListDisplayHandler: ald, + SoftAuctionConfig: auctionCfg, + } + auctionListSelector, _ := metachain.NewAuctionListSelector(argsAuctionListSelector) + + argsEpochSystemSC := metachain.ArgsNewEpochStartSystemSCProcessing{ + SystemVM: systemVM, + UserAccountsDB: tpn.AccntState, + PeerAccountsDB: tpn.PeerState, + Marshalizer: TestMarshalizer, + StartRating: tpn.RatingsData.StartRating(), + ValidatorInfoCreator: tpn.ValidatorStatisticsProcessor, + EndOfEpochCallerAddress: vm.EndOfEpochAddress, + StakingSCAddress: vm.StakingSCAddress, + ChanceComputer: tpn.NodesCoordinator, + EpochNotifier: tpn.EpochNotifier, + GenesisNodesConfig: tpn.NodesSetup, + StakingDataProvider: stakingDataProvider, + NodesConfigProvider: tpn.NodesCoordinator, + ShardCoordinator: tpn.ShardCoordinator, + ESDTOwnerAddressBytes: vm.EndOfEpochAddress, + EnableEpochsHandler: tpn.EnableEpochsHandler, + AuctionListSelector: auctionListSelector, + MaxNodesChangeConfigProvider: maxNodesChangeConfigProvider, + } + epochStartSystemSCProcessor, _ := metachain.NewSystemSCProcessor(argsEpochSystemSC) + tpn.EpochStartSystemSCProcessor = epochStartSystemSCProcessor + + arguments := block.ArgMetaProcessor{ + ArgBaseProcessor: argumentsBase, + SCToProtocol: scToProtocolInstance, + PendingMiniBlocksHandler: &mock.PendingMiniBlocksHandlerStub{}, + EpochEconomics: epochEconomics, + EpochStartDataCreator: epochStartDataCreator, + 
EpochRewardsCreator: epochStartRewards, + EpochValidatorInfoCreator: epochStartValidatorInfo, + ValidatorStatisticsProcessor: tpn.ValidatorStatisticsProcessor, + EpochSystemSCProcessor: epochStartSystemSCProcessor, + } + + tpn.BlockProcessor, err = block.NewMetaProcessor(arguments) + if err != nil { + log.Error("error creating meta blockprocessor", "error", err) + } + } else { + argumentsBase.EpochStartTrigger = tpn.EpochStartTrigger + argumentsBase.BlockChainHook = tpn.BlockchainHook + argumentsBase.TxCoordinator = tpn.TxCoordinator + argumentsBase.ScheduledTxsExecutionHandler = &testscommon.ScheduledTxsExecutionStub{} + + arguments := block.ArgShardProcessor{ + ArgBaseProcessor: argumentsBase, + } + + tpn.BlockProcessor, err = block.NewShardProcessor(arguments) + if err != nil { + log.Error("error creating shard blockprocessor", "error", err) + } + } + +} + +func (tpn *TestFullNode) initBlockProcessorWithSync( + coreComponents *mock.CoreComponentsStub, + dataComponents *mock.DataComponentsStub, + roundHandler consensus.RoundHandler, +) { + var err error + + accountsDb := make(map[state.AccountsDbIdentifier]state.AccountsAdapter) + accountsDb[state.UserAccountsState] = tpn.AccntState + accountsDb[state.PeerAccountsState] = tpn.PeerState + + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + + bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator, tpn.EnableEpochsHandler) + bootstrapComponents.HdrIntegrityVerifier = tpn.HeaderIntegrityVerifier + + statusComponents := GetDefaultStatusComponents() + + statusCoreComponents := &factory.StatusCoreComponentsStub{ + AppStatusHandlerField: &statusHandlerMock.AppStatusHandlerStub{}, + } + + argumentsBase := block.ArgBaseProcessor{ + CoreComponents: coreComponents, + DataComponents: dataComponents, + BootstrapComponents: bootstrapComponents, + StatusComponents: statusComponents, + StatusCoreComponents: statusCoreComponents, + Config: config.Config{}, + AccountsDB: accountsDb, + ForkDetector: nil, + NodesCoordinator: tpn.NodesCoordinator, + FeeHandler: tpn.FeeAccumulator, + RequestHandler: tpn.RequestHandler, + BlockChainHook: &testscommon.BlockChainHookStub{}, + EpochStartTrigger: &mock.EpochStartTriggerStub{}, + HeaderValidator: tpn.HeaderValidator, + BootStorer: &mock.BoostrapStorerMock{ + PutCalled: func(round int64, bootData bootstrapStorage.BootstrapData) error { + return nil + }, + }, + BlockTracker: tpn.BlockTracker, + BlockSizeThrottler: TestBlockSizeThrottler, + HistoryRepository: tpn.HistoryRepository, + GasHandler: tpn.GasHandler, + ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, + ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, + ReceiptsRepository: &testscommon.ReceiptsRepositoryStub{}, + OutportDataProvider: &outport.OutportDataProviderStub{}, + BlockProcessingCutoffHandler: &testscommon.BlockProcessingCutoffStub{}, + ManagedPeersHolder: &testscommon.ManagedPeersHolderStub{}, + SentSignaturesTracker: &testscommon.SentSignatureTrackerStub{}, + } + + if tpn.ShardCoordinator.SelfId() == core.MetachainShardId { + argumentsBase.ForkDetector = tpn.ForkDetector + argumentsBase.TxCoordinator = &mock.TransactionCoordinatorMock{} + arguments := block.ArgMetaProcessor{ + ArgBaseProcessor: argumentsBase, + SCToProtocol: &mock.SCToProtocolStub{}, + PendingMiniBlocksHandler: 
&mock.PendingMiniBlocksHandlerStub{}, + EpochStartDataCreator: &mock.EpochStartDataCreatorStub{}, + EpochEconomics: &mock.EpochEconomicsStub{}, + EpochRewardsCreator: &testscommon.RewardsCreatorStub{}, + EpochValidatorInfoCreator: &testscommon.EpochValidatorInfoCreatorStub{}, + ValidatorStatisticsProcessor: &testscommon.ValidatorStatisticsProcessorStub{ + UpdatePeerStateCalled: func(header data.MetaHeaderHandler) ([]byte, error) { + return []byte("validator stats root hash"), nil + }, + }, + EpochSystemSCProcessor: &testscommon.EpochStartSystemSCStub{}, + } + + tpn.BlockProcessor, err = block.NewMetaProcessor(arguments) + } else { + argumentsBase.ForkDetector = tpn.ForkDetector + argumentsBase.BlockChainHook = tpn.BlockchainHook + argumentsBase.TxCoordinator = tpn.TxCoordinator + argumentsBase.ScheduledTxsExecutionHandler = &testscommon.ScheduledTxsExecutionStub{} + arguments := block.ArgShardProcessor{ + ArgBaseProcessor: argumentsBase, + } + + tpn.BlockProcessor, err = block.NewShardProcessor(arguments) + } + + if err != nil { + panic(fmt.Sprintf("Error creating blockprocessor: %s", err.Error())) + } +} + +func (tpn *TestFullNode) initBlockTracker( + roundHandler consensus.RoundHandler, +) { + argBaseTracker := track.ArgBaseTracker{ + Hasher: TestHasher, + HeaderValidator: tpn.HeaderValidator, + Marshalizer: TestMarshalizer, + RequestHandler: tpn.RequestHandler, + RoundHandler: roundHandler, + ShardCoordinator: tpn.ShardCoordinator, + Store: tpn.Storage, + StartHeaders: tpn.GenesisBlocks, + PoolsHolder: tpn.DataPool, + WhitelistHandler: tpn.WhiteListHandler, + FeeHandler: tpn.EconomicsData, + EnableEpochsHandler: tpn.EnableEpochsHandler, + ProofsPool: tpn.DataPool.Proofs(), + EpochChangeGracePeriodHandler: tpn.EpochChangeGracePeriodHandler, + } + + var err error + if tpn.ShardCoordinator.SelfId() != core.MetachainShardId { + arguments := track.ArgShardTracker{ + ArgBaseTracker: argBaseTracker, + } + + tpn.BlockTracker, err = track.NewShardBlockTrack(arguments) + if err != nil { + log.Error("NewShardBlockTrack", "error", err) + } + } else { + arguments := track.ArgMetaTracker{ + ArgBaseTracker: argBaseTracker, + } + + tpn.BlockTracker, err = track.NewMetaBlockTrack(arguments) + if err != nil { + log.Error("NewMetaBlockTrack", "error", err) + } + } +} diff --git a/integrationTests/testHeartbeatNode.go b/integrationTests/testHeartbeatNode.go index 1ba488b9e12..23b13c40e3d 100644 --- a/integrationTests/testHeartbeatNode.go +++ b/integrationTests/testHeartbeatNode.go @@ -21,6 +21,8 @@ import ( "github.com/multiversx/mx-chain-crypto-go/signing/mcl" "github.com/multiversx/mx-chain-crypto-go/signing/mcl/singlesig" "github.com/multiversx/mx-chain-crypto-go/signing/secp256k1" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -48,6 +50,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/cache" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" 
"github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" @@ -58,7 +61,6 @@ import ( trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" "github.com/multiversx/mx-chain-go/update" - "github.com/stretchr/testify/require" ) // constants used for the hearbeat node & generated messages @@ -349,8 +351,14 @@ func CreateNodesWithTestHeartbeatNode( suCache, _ := storageunit.NewCache(cacherCfg) for shardId, validatorList := range validatorsMap { argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: shardConsensusGroupSize, - MetaConsensusGroupSize: metaConsensusGroupSize, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(shardConsensusGroupSize), + MetachainConsensusGroupSize: uint32(metaConsensusGroupSize), + }, nil + }, + }, Marshalizer: TestMarshalizer, Hasher: TestHasher, ShardIDAsObserver: shardId, @@ -397,8 +405,14 @@ func CreateNodesWithTestHeartbeatNode( } argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: shardConsensusGroupSize, - MetaConsensusGroupSize: metaConsensusGroupSize, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(shardConsensusGroupSize), + MetachainConsensusGroupSize: uint32(metaConsensusGroupSize), + }, nil + }, + }, Marshalizer: TestMarshalizer, Hasher: TestHasher, ShardIDAsObserver: shardId, @@ -554,6 +568,11 @@ func (thn *TestHeartbeatNode) initResolversAndRequesters() { FullArchivePreferredPeersHolder: &p2pmocks.PeersHolderStub{}, PeersRatingHandler: &p2pmocks.PeersRatingHandlerStub{}, SizeCheckDelta: 0, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + }, } if thn.ShardCoordinator.SelfId() == core.MetachainShardId { @@ -694,6 +713,7 @@ func (thn *TestHeartbeatNode) initMultiDataInterceptor(topic string, dataFactory interceptors.ArgMultiDataInterceptor{ Topic: topic, Marshalizer: TestMarshalizer, + Hasher: TestHasher, DataFactory: dataFactory, Processor: processor, Throttler: TestThrottler, @@ -703,8 +723,9 @@ func (thn *TestHeartbeatNode) initMultiDataInterceptor(topic string, dataFactory return true }, }, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - CurrentPeerId: thn.MainMessenger.ID(), + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + CurrentPeerId: thn.MainMessenger.ID(), + InterceptedDataVerifier: &processMock.InterceptedDataVerifierMock{}, }, ) @@ -726,8 +747,9 @@ func (thn *TestHeartbeatNode) initSingleDataInterceptor(topic string, dataFactor return true }, }, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - CurrentPeerId: thn.MainMessenger.ID(), + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + CurrentPeerId: thn.MainMessenger.ID(), + InterceptedDataVerifier: &processMock.InterceptedDataVerifierMock{}, }, ) diff --git a/integrationTests/testInitializer.go b/integrationTests/testInitializer.go index f98ef2a5288..1a799d8eeb6 100644 --- 
a/integrationTests/testInitializer.go +++ b/integrationTests/testInitializer.go @@ -29,7 +29,14 @@ import ( "github.com/multiversx/mx-chain-crypto-go/signing/ed25519" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" "github.com/multiversx/mx-chain-crypto-go/signing/secp256k1" + logger "github.com/multiversx/mx-chain-logger-go" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/common/statistics" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -79,10 +86,6 @@ import ( "github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" - logger "github.com/multiversx/mx-chain-logger-go" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" ) // StepDelay is used so that transactions can disseminate properly @@ -408,6 +411,7 @@ func CreateStore(numOfShards uint32) dataRetriever.StorageService { store.AddStorer(dataRetriever.StatusMetricsUnit, CreateMemUnit()) store.AddStorer(dataRetriever.ReceiptsUnit, CreateMemUnit()) store.AddStorer(dataRetriever.ScheduledSCRsUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.ProofsUnit, CreateMemUnit()) for i := uint32(0); i < numOfShards; i++ { hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(i) @@ -651,7 +655,9 @@ func CreateFullGenesisBlocks( gasSchedule := wasmConfig.MakeGasMapForTests() defaults.FillGasMapInternal(gasSchedule, 1) - coreComponents := GetDefaultCoreComponents(enableEpochsConfig) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(enableEpochsConfig, genericEpochNotifier) + coreComponents := GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.TxMarshalizerField = TestTxSignMarshalizer coreComponents.HasherField = TestHasher @@ -776,7 +782,9 @@ func CreateGenesisMetaBlock( gasSchedule := wasmConfig.MakeGasMapForTests() defaults.FillGasMapInternal(gasSchedule, 1) - coreComponents := GetDefaultCoreComponents(enableEpochsConfig) + genericEpochNotifier := forking.NewGenericEpochNotifier() + enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(enableEpochsConfig, genericEpochNotifier) + coreComponents := GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = marshalizer coreComponents.HasherField = hasher coreComponents.Uint64ByteSliceConverterField = uint64Converter @@ -1148,21 +1156,53 @@ func IncrementAndPrintRound(round uint64) uint64 { } // ProposeBlock proposes a block for every shard -func ProposeBlock(nodes []*TestProcessorNode, idxProposers []int, round uint64, nonce uint64) { +func ProposeBlock(nodes []*TestProcessorNode, leaders []*TestProcessorNode, round uint64, 
nonce uint64) { log.Info("All shards propose blocks...") stepDelayAdjustment := StepDelay * time.Duration(1+len(nodes)/3) - for idx, n := range nodes { - if !IsIntInSlice(idx, idxProposers) { - continue - } + for _, n := range leaders { + body, header, _ := n.ProposeBlock(round, nonce) + n.WhiteListBody(nodes, body) + pk := n.NodeKeys.MainKey.Pk + n.BroadcastBlock(body, header, pk) + + _ = addProofIfNeeded(n, header) + n.CommitBlock(body, header) + } + + log.Info("Delaying for disseminating headers and miniblocks...") + time.Sleep(stepDelayAdjustment) + log.Info("Proposed block\n" + MakeDisplayTable(nodes)) +} + +// ProposeBlockWithProof proposes a block for every shard with custom handling for equivalent proof +func ProposeBlockWithProof( + nodes []*TestProcessorNode, + leaders []*TestProcessorNode, + round uint64, + nonce uint64, +) { + log.Info("All shards propose blocks with proof...") + stepDelayAdjustment := StepDelay * time.Duration(1+len(nodes)/3) + + for _, n := range leaders { body, header, _ := n.ProposeBlock(round, nonce) n.WhiteListBody(nodes, body) pk := n.NodeKeys.MainKey.Pk n.BroadcastBlock(body, header, pk) + + proof := addProofIfNeeded(n, header) n.CommitBlock(body, header) + + time.Sleep(SyncDelay) + + // cleanup proof from pool before broadcasting so that the interceptor will propagate the proof to the other nodes + _ = n.Node.GetDataComponents().Datapool().Proofs().CleanupProofsBehindNonce(n.ShardCoordinator.SelfId(), nonce+4) // default cleanup delta is 3 + + n.BroadcastProof(proof, pk) + } log.Info("Delaying for disseminating headers and miniblocks...") @@ -1170,17 +1210,40 @@ func ProposeBlock(nodes []*TestProcessorNode, idxProposers []int, round uint64, log.Info("Proposed block\n" + MakeDisplayTable(nodes)) } +func addProofIfNeeded(node *TestProcessorNode, header data.HeaderHandler) data.HeaderProofHandler { + coreComp := node.Node.GetCoreComponents() + if !common.IsProofsFlagEnabledForHeader(coreComp.EnableEpochsHandler(), header) { + return nil + } + + hash, _ := core.CalculateHash(coreComp.InternalMarshalizer(), coreComp.Hasher(), header) + proof := &dataBlock.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: hash, + HeaderEpoch: header.GetEpoch(), + HeaderNonce: header.GetNonce(), + HeaderShardId: header.GetShardID(), + HeaderRound: header.GetRound(), + IsStartOfEpoch: header.IsStartOfEpochBlock(), + } + + node.Node.GetDataComponents().Datapool().Proofs().AddProof(proof) + + return proof +} + // SyncBlock synchronizes the proposed block in all the other shard nodes func SyncBlock( t *testing.T, nodes []*TestProcessorNode, - idxProposers []int, + leaders []*TestProcessorNode, round uint64, ) { log.Info("All other shard nodes sync the proposed block...") - for idx, n := range nodes { - if IsIntInSlice(idx, idxProposers) { + for _, n := range nodes { + if IsNodeInSlice(n, leaders) { continue } @@ -1196,10 +1259,9 @@ func SyncBlock( log.Info("Synchronized block\n" + MakeDisplayTable(nodes)) } -// IsIntInSlice returns true if idx is found on any position in the provided slice -func IsIntInSlice(idx int, slice []int) bool { +func IsNodeInSlice(node *TestProcessorNode, slice []*TestProcessorNode) bool { for _, value := range slice { - if value == idx { + if value == node { return true } } @@ -2194,7 +2256,9 @@ func generateValidTx( _ = accnts.SaveAccount(acc) _, _ = accnts.Commit() - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + genericEpochNotifier := forking.NewGenericEpochNotifier() + 
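// GetDefaultCoreComponents now receives an enable-epochs handler and an epoch notifier instead of the raw enable-epochs config +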
enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), genericEpochNotifier) + coreComponents := GetDefaultCoreComponents(enableEpochsHandler, genericEpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.TxMarshalizerField = TestTxSignMarshalizer coreComponents.VmMarshalizerField = TestMarshalizer @@ -2246,14 +2310,14 @@ func generateValidTx( func ProposeAndSyncOneBlock( t *testing.T, nodes []*TestProcessorNode, - idxProposers []int, + leaders []*TestProcessorNode, round uint64, nonce uint64, ) (uint64, uint64) { UpdateRound(nodes, round) - ProposeBlock(nodes, idxProposers, round, nonce) - SyncBlock(t, nodes, idxProposers, round) + ProposeBlock(nodes, leaders, round, nonce) + SyncBlock(t, nodes, leaders, round) round = IncrementAndPrintRound(round) nonce++ @@ -2424,7 +2488,7 @@ func BootstrapDelay() { func SetupSyncNodesOneShardAndMeta( numNodesPerShard int, numNodesMeta int, -) ([]*TestProcessorNode, []int) { +) ([]*TestProcessorNode, []*TestProcessorNode) { maxShardsLocal := uint32(1) shardId := uint32(0) @@ -2441,7 +2505,7 @@ func SetupSyncNodesOneShardAndMeta( nodes = append(nodes, shardNode) connectableNodes = append(connectableNodes, shardNode) } - idxProposerShard0 := 0 + leaderShard0 := nodes[0] for i := 0; i < numNodesMeta; i++ { metaNode := NewTestProcessorNode(ArgTestProcessorNode{ @@ -2453,13 +2517,13 @@ func SetupSyncNodesOneShardAndMeta( nodes = append(nodes, metaNode) connectableNodes = append(connectableNodes, metaNode) } - idxProposerMeta := len(nodes) - 1 + leaderMeta := nodes[len(nodes)-1] - idxProposers := []int{idxProposerShard0, idxProposerMeta} + leaders := []*TestProcessorNode{leaderShard0, leaderMeta} ConnectNodes(connectableNodes) - return nodes, idxProposers + return nodes, leaders } // StartSyncingBlocks starts the syncing process of all the nodes @@ -2541,14 +2605,14 @@ func UpdateRound(nodes []*TestProcessorNode, round uint64) { func ProposeBlocks( nodes []*TestProcessorNode, round *uint64, - idxProposers []int, + leaders []*TestProcessorNode, nonces []*uint64, numOfRounds int, ) { for i := 0; i < numOfRounds; i++ { crtRound := atomic.LoadUint64(round) - proposeBlocks(nodes, idxProposers, nonces, crtRound) + proposeBlocks(nodes, leaders, nonces, crtRound) time.Sleep(SyncDelay) @@ -2569,20 +2633,20 @@ func IncrementNonces(nonces []*uint64) { func proposeBlocks( nodes []*TestProcessorNode, - idxProposers []int, + leaders []*TestProcessorNode, nonces []*uint64, crtRound uint64, ) { - for idx, proposer := range idxProposers { + for idx, proposer := range leaders { crtNonce := atomic.LoadUint64(nonces[idx]) - ProposeBlock(nodes, []int{proposer}, crtRound, crtNonce) + ProposeBlock(nodes, []*TestProcessorNode{proposer}, crtRound, crtNonce) } } // WaitOperationToBeDone - -func WaitOperationToBeDone(t *testing.T, nodes []*TestProcessorNode, nrOfRounds int, nonce uint64, round uint64, idxProposers []int) (uint64, uint64) { +func WaitOperationToBeDone(t *testing.T, leaders []*TestProcessorNode, nodes []*TestProcessorNode, nrOfRounds int, nonce uint64, round uint64) (uint64, uint64) { for i := 0; i < nrOfRounds; i++ { - round, nonce = ProposeAndSyncOneBlock(t, nodes, idxProposers, round, nonce) + round, nonce = ProposeAndSyncOneBlock(t, nodes, leaders, round, nonce) } return nonce, round diff --git a/integrationTests/testNetwork.go b/integrationTests/testNetwork.go index a08b3aa85c7..c64651fab3f 100644 --- a/integrationTests/testNetwork.go +++ b/integrationTests/testNetwork.go @@ -8,12 +8,13 @@ import ( 
"testing" "github.com/multiversx/mx-chain-core-go/data/transaction" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/require" ) // ShardIdentifier is the numeric index of a shard @@ -44,7 +45,7 @@ type TestNetwork struct { NodesSharded NodesByShardMap Wallets []*TestWalletAccount DeploymentAddress Address - Proposers []int + Proposers []*TestProcessorNode Round uint64 Nonce uint64 T *testing.T @@ -119,11 +120,11 @@ func (net *TestNetwork) Step() { func (net *TestNetwork) Steps(steps int) { net.Nonce, net.Round = WaitOperationToBeDone( net.T, + net.Proposers, net.Nodes, steps, net.Nonce, - net.Round, - net.Proposers) + net.Round) } // Close shuts down the test network. @@ -421,6 +422,7 @@ func (net *TestNetwork) createNodes() { StakingV2EnableEpoch: UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, + AndromedaEnableEpoch: UnreachableEpoch, } net.Nodes = CreateNodesWithEnableEpochs( @@ -432,11 +434,11 @@ func (net *TestNetwork) createNodes() { } func (net *TestNetwork) indexProposers() { - net.Proposers = make([]int, net.NumShards+1) + net.Proposers = make([]*TestProcessorNode, net.NumShards+1) for i := 0; i < net.NumShards; i++ { - net.Proposers[i] = i * net.NodesPerShard + net.Proposers[i] = net.Nodes[i*net.NodesPerShard] } - net.Proposers[net.NumShards] = net.NumShards * net.NodesPerShard + net.Proposers[net.NumShards] = net.Nodes[net.NumShards*net.NodesPerShard] } func (net *TestNetwork) mapNodesByShard() { diff --git a/integrationTests/testProcessorNode.go b/integrationTests/testProcessorNode.go index a2b8598bd06..e76a2dedc2c 100644 --- a/integrationTests/testProcessorNode.go +++ b/integrationTests/testProcessorNode.go @@ -31,17 +31,23 @@ import ( ed25519SingleSig "github.com/multiversx/mx-chain-crypto-go/signing/ed25519/singlesig" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" mclsig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/singlesig" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/multiversx/mx-chain-vm-common-go/parsers" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + nodeFactory "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/enablers" "github.com/multiversx/mx-chain-go/common/errChan" "github.com/multiversx/mx-chain-go/common/forking" + "github.com/multiversx/mx-chain-go/common/graceperiod" "github.com/multiversx/mx-chain-go/common/ordering" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/spos/sposFactory" "github.com/multiversx/mx-chain-go/dataRetriever" 
"github.com/multiversx/mx-chain-go/dataRetriever/blockchain" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/dataRetriever/factory/containers" requesterscontainer "github.com/multiversx/mx-chain-go/dataRetriever/factory/requestersContainer" "github.com/multiversx/mx-chain-go/dataRetriever/factory/resolverscontainer" @@ -78,6 +84,7 @@ import ( "github.com/multiversx/mx-chain-go/process/factory/shard" "github.com/multiversx/mx-chain-go/process/heartbeat/validator" "github.com/multiversx/mx-chain-go/process/interceptors" + interceptorsFactory "github.com/multiversx/mx-chain-go/process/interceptors/factory" processMock "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/peer" "github.com/multiversx/mx-chain-go/process/rating" @@ -104,10 +111,14 @@ import ( "github.com/multiversx/mx-chain-go/storage/txcache" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks" + cacheMocks "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" dblookupextMock "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" testFactory "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" "github.com/multiversx/mx-chain-go/testscommon/guardianMocks" @@ -126,9 +137,6 @@ import ( "github.com/multiversx/mx-chain-go/vm" vmProcess "github.com/multiversx/mx-chain-go/vm/process" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/multiversx/mx-chain-vm-common-go/parsers" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" ) var zero = big.NewInt(0) @@ -138,6 +146,9 @@ var hardforkPubKey = "153dae6cb3963260f309959bf285537b77ae16d82e9933147be7827f73 // TestHasher represents a sha256 hasher var TestHasher = sha256.NewSha256() +// TestEpochChangeGracePeriod represents the grace period for epoch change handler +var TestEpochChangeGracePeriod, _ = graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) + // TestTxSignHasher represents a sha3 legacy keccak 256 hasher var TestTxSignHasher = keccak.NewKeccak() @@ -302,6 +313,7 @@ type ArgTestProcessorNode struct { StatusMetrics external.StatusMetricsHandler WithPeersRatingHandler bool NodeOperationMode common.NodeOperation + Proofs dataRetriever.ProofsPool } // TestProcessorNode represents a container type of class used in integration tests @@ -328,6 +340,7 @@ type 
TestProcessorNode struct { TrieContainer common.TriesHolder BlockChain data.ChainHandler GenesisBlocks map[uint32]data.HeaderHandler + ProofsPool dataRetriever.ProofsPool EconomicsData *economics.TestEconomicsData RatingsData *rating.RatingsData @@ -402,15 +415,16 @@ type TestProcessorNode struct { ChainID []byte MinTransactionVersion uint32 - ExportHandler update.ExportHandler - WaitTime time.Duration - HistoryRepository dblookupext.HistoryRepository - EpochNotifier process.EpochNotifier - RoundNotifier process.RoundNotifier - EnableEpochs config.EnableEpochs - EnableRoundsHandler process.EnableRoundsHandler - EnableEpochsHandler common.EnableEpochsHandler - UseValidVmBlsSigVerifier bool + ExportHandler update.ExportHandler + WaitTime time.Duration + HistoryRepository dblookupext.HistoryRepository + EpochNotifier process.EpochNotifier + RoundNotifier process.RoundNotifier + EnableEpochs config.EnableEpochs + EnableRoundsHandler process.EnableRoundsHandler + EnableEpochsHandler common.EnableEpochsHandler + EpochChangeGracePeriodHandler common.EpochChangeGracePeriodHandler + UseValidVmBlsSigVerifier bool TransactionLogProcessor process.TransactionLogProcessor PeersRatingHandler p2p.PeersRatingHandler @@ -464,8 +478,8 @@ func newBaseTestProcessorNode(args ArgTestProcessorNode) *TestProcessorNode { var peersRatingMonitor p2p.PeersRatingMonitor peersRatingMonitor = &p2pmocks.PeersRatingMonitorStub{} if args.WithPeersRatingHandler { - topRatedCache := testscommon.NewCacherMock() - badRatedCache := testscommon.NewCacherMock() + topRatedCache := cacheMocks.NewCacherMock() + badRatedCache := cacheMocks.NewCacherMock() peersRatingHandler, _ = p2pFactory.NewPeersRatingHandler( p2pFactory.ArgPeersRatingHandler{ TopRatedCache: topRatedCache, @@ -504,37 +518,38 @@ func newBaseTestProcessorNode(args ArgTestProcessorNode) *TestProcessorNode { logsProcessor, _ := transactionLog.NewTxLogProcessor(transactionLog.ArgTxLogProcessor{Marshalizer: TestMarshalizer}) tpn := &TestProcessorNode{ - ShardCoordinator: shardCoordinator, - MainMessenger: messenger, - FullArchiveMessenger: fullArchiveMessenger, - NodeOperationMode: nodeOperationMode, - NodesCoordinator: nodesCoordinatorInstance, - ChainID: ChainID, - MinTransactionVersion: MinTransactionVersion, - NodesSetup: nodesSetup, - HistoryRepository: &dblookupextMock.HistoryRepositoryStub{}, - EpochNotifier: genericEpochNotifier, - RoundNotifier: genericRoundNotifier, - EnableRoundsHandler: enableRoundsHandler, - EnableEpochsHandler: enableEpochsHandler, - EpochProvider: &mock.CurrentNetworkEpochProviderStub{}, - WasmVMChangeLocker: &sync.RWMutex{}, - TransactionLogProcessor: logsProcessor, - Bootstrapper: mock.NewTestBootstrapperMock(), - PeersRatingHandler: peersRatingHandler, - MainPeerShardMapper: mock.NewNetworkShardingCollectorMock(), - FullArchivePeerShardMapper: mock.NewNetworkShardingCollectorMock(), - EnableEpochs: *epochsConfig, - UseValidVmBlsSigVerifier: args.WithBLSSigVerifier, - StorageBootstrapper: &mock.StorageBootstrapperMock{}, - BootstrapStorer: &mock.BoostrapStorerMock{}, - RatingsData: args.RatingsData, - EpochStartNotifier: args.EpochStartSubscriber, - GuardedAccountHandler: &guardianMocks.GuardedAccountHandlerStub{}, - AppStatusHandler: appStatusHandler, - PeersRatingMonitor: peersRatingMonitor, - TxExecutionOrderHandler: ordering.NewOrderedCollection(), - EpochStartTrigger: &mock.EpochStartTriggerStub{}, + ShardCoordinator: shardCoordinator, + MainMessenger: messenger, + FullArchiveMessenger: fullArchiveMessenger, + NodeOperationMode: 
nodeOperationMode, + NodesCoordinator: nodesCoordinatorInstance, + ChainID: ChainID, + MinTransactionVersion: MinTransactionVersion, + NodesSetup: nodesSetup, + HistoryRepository: &dblookupextMock.HistoryRepositoryStub{}, + EpochNotifier: genericEpochNotifier, + RoundNotifier: genericRoundNotifier, + EnableRoundsHandler: enableRoundsHandler, + EnableEpochsHandler: enableEpochsHandler, + EpochChangeGracePeriodHandler: TestEpochChangeGracePeriod, + EpochProvider: &mock.CurrentNetworkEpochProviderStub{}, + WasmVMChangeLocker: &sync.RWMutex{}, + TransactionLogProcessor: logsProcessor, + Bootstrapper: mock.NewTestBootstrapperMock(), + PeersRatingHandler: peersRatingHandler, + MainPeerShardMapper: mock.NewNetworkShardingCollectorMock(), + FullArchivePeerShardMapper: mock.NewNetworkShardingCollectorMock(), + EnableEpochs: *epochsConfig, + UseValidVmBlsSigVerifier: args.WithBLSSigVerifier, + StorageBootstrapper: &mock.StorageBootstrapperMock{}, + BootstrapStorer: &mock.BoostrapStorerMock{}, + RatingsData: args.RatingsData, + EpochStartNotifier: args.EpochStartSubscriber, + GuardedAccountHandler: &guardianMocks.GuardedAccountHandlerStub{}, + AppStatusHandler: appStatusHandler, + PeersRatingMonitor: peersRatingMonitor, + TxExecutionOrderHandler: ordering.NewOrderedCollection(), + EpochStartTrigger: &mock.EpochStartTriggerStub{}, } tpn.NodeKeys = args.NodeKeys @@ -557,11 +572,6 @@ func newBaseTestProcessorNode(args ArgTestProcessorNode) *TestProcessorNode { tpn.OwnAccount = CreateTestWalletAccount(shardCoordinator, args.TxSignPrivKeyShardId) } - tpn.HeaderSigVerifier = args.HeaderSigVerifier - if check.IfNil(tpn.HeaderSigVerifier) { - tpn.HeaderSigVerifier = &mock.HeaderSigVerifierStub{} - } - tpn.HeaderIntegrityVerifier = args.HeaderIntegrityVerifier if check.IfNil(tpn.HeaderIntegrityVerifier) { tpn.HeaderIntegrityVerifier = CreateHeaderIntegrityVerifier() @@ -571,9 +581,19 @@ func newBaseTestProcessorNode(args ArgTestProcessorNode) *TestProcessorNode { if !check.IfNil(args.DataPool) { tpn.DataPool = args.DataPool + tpn.ProofsPool = tpn.DataPool.Proofs() _ = messenger.SetThresholdMinConnectedPeers(minConnectedPeers) } + tpn.HeaderSigVerifier = args.HeaderSigVerifier + if check.IfNil(tpn.HeaderSigVerifier) { + tpn.HeaderSigVerifier = &consensusMocks.HeaderSigVerifierMock{ + VerifyHeaderProofCalled: func(proofHandler data.HeaderProofHandler) error { + return nil + }, + } + } + return tpn } @@ -881,7 +901,7 @@ func (tpn *TestProcessorNode) createFullSCQueryService(gasMap map[string]map[str argsBuiltIn.AutomaticCrawlerAddresses = GenerateOneAddressPerShard(argsBuiltIn.ShardCoordinator) builtInFuncFactory, _ := builtInFunctions.CreateBuiltInFunctionsFactory(argsBuiltIn) - smartContractsCache := testscommon.NewCacherMock() + smartContractsCache := cacheMocks.NewCacherMock() argsHook := hooks.ArgBlockChainHook{ Accounts: tpn.AccntState, @@ -1083,7 +1103,8 @@ func (tpn *TestProcessorNode) InitializeProcessors(gasMap map[string]map[string] } func (tpn *TestProcessorNode) initDataPools() { - tpn.DataPool = dataRetrieverMock.CreatePoolsHolder(1, tpn.ShardCoordinator.SelfId()) + tpn.ProofsPool = proofscache.NewProofsPool(3, 100) + tpn.DataPool = dataRetrieverMock.CreatePoolsHolderWithProofsPool(1, tpn.ShardCoordinator.SelfId(), tpn.ProofsPool) cacherCfg := storageunit.CacheConfig{Capacity: 10000, Type: storageunit.LRUCache, Shards: 1} suCache, _ := storageunit.NewCache(cacherCfg) tpn.WhiteListHandler, _ = interceptors.NewWhiteListDataVerifier(suCache) @@ -1179,21 +1200,25 @@ func (tpn *TestProcessorNode) 
initRatingsData() { func CreateRatingsData() *rating.RatingsData { ratingsConfig := config.RatingsConfig{ ShardChain: config.ShardChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 50, - ProposerValidatorImportance: 1, - ProposerDecreaseFactor: -4, - ValidatorDecreaseFactor: -4, - ConsecutiveMissedBlocksPenalty: 1.1, + RatingStepsByEpoch: []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 50, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.1, + }, }, }, MetaChain: config.MetaChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 50, - ProposerValidatorImportance: 1, - ProposerDecreaseFactor: -4, - ValidatorDecreaseFactor: -4, - ConsecutiveMissedBlocksPenalty: 1.1, + RatingStepsByEpoch: []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 50, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.1, + }, }, }, General: config.General{ @@ -1251,12 +1276,10 @@ func CreateRatingsData() *rating.RatingsData { } ratingDataArgs := rating.RatingsDataArg{ - Config: ratingsConfig, - ShardConsensusSize: 63, - MetaConsensusSize: 400, - ShardMinNodes: 400, - MetaMinNodes: 400, - RoundDurationMiliseconds: 6000, + Config: ratingsConfig, + ChainParametersHolder: &chainParameters.ChainParametersHolderMock{}, + EpochNotifier: &epochNotifier.EpochNotifierStub{}, + RoundDurationMilliseconds: 6000, } ratingsData, _ := rating.NewRatingsData(ratingDataArgs) @@ -1270,7 +1293,13 @@ func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { tpn.EpochStartNotifier = notifier.NewEpochStartSubscriptionHandler() } - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + coreComponents := GetDefaultCoreComponents(tpn.EnableEpochsHandler, tpn.EpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.TxMarshalizerField = TestTxSignMarshalizer coreComponents.HasherField = TestHasher @@ -1295,6 +1324,11 @@ func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { cryptoComponents.BlKeyGen = tpn.OwnAccount.KeygenBlockSign cryptoComponents.TxKeyGen = tpn.OwnAccount.KeygenTxSign + interceptorDataVerifierArgs := interceptorsFactory.InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Second * 3, + CacheExpiry: time.Second * 10, + } + if tpn.ShardCoordinator.SelfId() == core.MetachainShardId { argsEpochStart := &metachain.ArgsNewMetaEpochStartTrigger{ GenesisTime: tpn.RoundHandler.TimeStamp(), @@ -1317,36 +1351,37 @@ func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { coreComponents.HardforkTriggerPubKeyField = providedHardforkPk metaInterceptorContainerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: coreComponents, - CryptoComponents: cryptoComponents, - Accounts: tpn.AccntState, - ShardCoordinator: tpn.ShardCoordinator, - NodesCoordinator: tpn.NodesCoordinator, - MainMessenger: tpn.MainMessenger, - FullArchiveMessenger: tpn.FullArchiveMessenger, - Store: tpn.Storage, - DataPool: tpn.DataPool, - MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, - TxFeeHandler: tpn.EconomicsData, - BlockBlackList: tpn.BlockBlackListHandler, - 
HeaderSigVerifier: tpn.HeaderSigVerifier, - HeaderIntegrityVerifier: tpn.HeaderIntegrityVerifier, - ValidityAttester: tpn.BlockTracker, - EpochStartTrigger: tpn.EpochStartTrigger, - WhiteListHandler: tpn.WhiteListHandler, - WhiteListerVerifiedTxs: tpn.WhiteListerVerifiedTxs, - AntifloodHandler: &mock.NilAntifloodHandler{}, - ArgumentsParser: smartContract.NewArgumentParser(), - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - SizeCheckDelta: sizeCheckDelta, - RequestHandler: tpn.RequestHandler, - PeerSignatureHandler: &processMock.PeerSignatureHandlerStub{}, - SignaturesHandler: &processMock.SignaturesHandlerStub{}, - HeartbeatExpiryTimespanInSec: 30, - MainPeerShardMapper: tpn.MainPeerShardMapper, - FullArchivePeerShardMapper: tpn.FullArchivePeerShardMapper, - HardforkTrigger: tpn.HardforkTrigger, - NodeOperationMode: tpn.NodeOperationMode, + CoreComponents: coreComponents, + CryptoComponents: cryptoComponents, + Accounts: tpn.AccntState, + ShardCoordinator: tpn.ShardCoordinator, + NodesCoordinator: tpn.NodesCoordinator, + MainMessenger: tpn.MainMessenger, + FullArchiveMessenger: tpn.FullArchiveMessenger, + Store: tpn.Storage, + DataPool: tpn.DataPool, + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: tpn.EconomicsData, + BlockBlackList: tpn.BlockBlackListHandler, + HeaderSigVerifier: tpn.HeaderSigVerifier, + HeaderIntegrityVerifier: tpn.HeaderIntegrityVerifier, + ValidityAttester: tpn.BlockTracker, + EpochStartTrigger: tpn.EpochStartTrigger, + WhiteListHandler: tpn.WhiteListHandler, + WhiteListerVerifiedTxs: tpn.WhiteListerVerifiedTxs, + AntifloodHandler: &mock.NilAntifloodHandler{}, + ArgumentsParser: smartContract.NewArgumentParser(), + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + SizeCheckDelta: sizeCheckDelta, + RequestHandler: tpn.RequestHandler, + PeerSignatureHandler: &processMock.PeerSignatureHandlerStub{}, + SignaturesHandler: &processMock.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + MainPeerShardMapper: tpn.MainPeerShardMapper, + FullArchivePeerShardMapper: tpn.FullArchivePeerShardMapper, + HardforkTrigger: tpn.HardforkTrigger, + NodeOperationMode: tpn.NodeOperationMode, + InterceptedDataVerifierFactory: interceptorsFactory.NewInterceptedDataVerifierFactory(interceptorDataVerifierArgs), } interceptorContainerFactory, _ := interceptorscontainer.NewMetaInterceptorsContainerFactory(metaInterceptorContainerFactoryArgs) @@ -1385,37 +1420,39 @@ func (tpn *TestProcessorNode) initInterceptors(heartbeatPk string) { coreComponents.HardforkTriggerPubKeyField = providedHardforkPk shardIntereptorContainerFactoryArgs := interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: coreComponents, - CryptoComponents: cryptoComponents, - Accounts: tpn.AccntState, - ShardCoordinator: tpn.ShardCoordinator, - NodesCoordinator: tpn.NodesCoordinator, - MainMessenger: tpn.MainMessenger, - FullArchiveMessenger: tpn.FullArchiveMessenger, - Store: tpn.Storage, - DataPool: tpn.DataPool, - MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, - TxFeeHandler: tpn.EconomicsData, - BlockBlackList: tpn.BlockBlackListHandler, - HeaderSigVerifier: tpn.HeaderSigVerifier, - HeaderIntegrityVerifier: tpn.HeaderIntegrityVerifier, - ValidityAttester: tpn.BlockTracker, - EpochStartTrigger: tpn.EpochStartTrigger, - WhiteListHandler: tpn.WhiteListHandler, - WhiteListerVerifiedTxs: tpn.WhiteListerVerifiedTxs, - AntifloodHandler: &mock.NilAntifloodHandler{}, - ArgumentsParser: smartContract.NewArgumentParser(), - PreferredPeersHolder: 
&p2pmocks.PeersHolderStub{}, - SizeCheckDelta: sizeCheckDelta, - RequestHandler: tpn.RequestHandler, - PeerSignatureHandler: &processMock.PeerSignatureHandlerStub{}, - SignaturesHandler: &processMock.SignaturesHandlerStub{}, - HeartbeatExpiryTimespanInSec: 30, - MainPeerShardMapper: tpn.MainPeerShardMapper, - FullArchivePeerShardMapper: tpn.FullArchivePeerShardMapper, - HardforkTrigger: tpn.HardforkTrigger, - NodeOperationMode: tpn.NodeOperationMode, + CoreComponents: coreComponents, + CryptoComponents: cryptoComponents, + Accounts: tpn.AccntState, + ShardCoordinator: tpn.ShardCoordinator, + NodesCoordinator: tpn.NodesCoordinator, + MainMessenger: tpn.MainMessenger, + FullArchiveMessenger: tpn.FullArchiveMessenger, + Store: tpn.Storage, + DataPool: tpn.DataPool, + MaxTxNonceDeltaAllowed: common.MaxTxNonceDeltaAllowed, + TxFeeHandler: tpn.EconomicsData, + BlockBlackList: tpn.BlockBlackListHandler, + HeaderSigVerifier: tpn.HeaderSigVerifier, + HeaderIntegrityVerifier: tpn.HeaderIntegrityVerifier, + ValidityAttester: tpn.BlockTracker, + EpochStartTrigger: tpn.EpochStartTrigger, + WhiteListHandler: tpn.WhiteListHandler, + WhiteListerVerifiedTxs: tpn.WhiteListerVerifiedTxs, + AntifloodHandler: &mock.NilAntifloodHandler{}, + ArgumentsParser: smartContract.NewArgumentParser(), + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + SizeCheckDelta: sizeCheckDelta, + RequestHandler: tpn.RequestHandler, + PeerSignatureHandler: &processMock.PeerSignatureHandlerStub{}, + SignaturesHandler: &processMock.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + MainPeerShardMapper: tpn.MainPeerShardMapper, + FullArchivePeerShardMapper: tpn.FullArchivePeerShardMapper, + HardforkTrigger: tpn.HardforkTrigger, + NodeOperationMode: tpn.NodeOperationMode, + InterceptedDataVerifierFactory: interceptorsFactory.NewInterceptedDataVerifierFactory(interceptorDataVerifierArgs), } + interceptorContainerFactory, _ := interceptorscontainer.NewShardInterceptorsContainerFactory(shardIntereptorContainerFactoryArgs) tpn.MainInterceptorsContainer, tpn.FullArchiveInterceptorsContainer, err = interceptorContainerFactory.Create() @@ -1515,6 +1552,7 @@ func (tpn *TestProcessorNode) initRequesters() { FullArchivePreferredPeersHolder: &p2pmocks.PeersHolderStub{}, PeersRatingHandler: tpn.PeersRatingHandler, SizeCheckDelta: 0, + EnableEpochsHandler: tpn.EnableEpochsHandler, } if tpn.ShardCoordinator.SelfId() == core.MetachainShardId { @@ -1762,7 +1800,7 @@ func (tpn *TestProcessorNode) initInnerProcessors(gasMap map[string]map[string]u ) processedMiniBlocksTracker := processedMb.NewProcessedMiniBlocksTracker() - fact, _ := shard.NewPreProcessorsContainerFactory( + fact, err := shard.NewPreProcessorsContainerFactory( tpn.ShardCoordinator, tpn.Storage, TestMarshalizer, @@ -1786,6 +1824,9 @@ func (tpn *TestProcessorNode) initInnerProcessors(gasMap map[string]map[string]u processedMiniBlocksTracker, tpn.TxExecutionOrderHandler, ) + if err != nil { + panic(err.Error()) + } tpn.PreProcessorsContainer, _ = fact.Create() argsTransactionCoordinator := coordinator.ArgTransactionCoordinator{ @@ -2163,23 +2204,21 @@ func (tpn *TestProcessorNode) addMockVm(blockchainHook vmcommon.BlockchainHook) func (tpn *TestProcessorNode) initBlockProcessor() { var err error - if tpn.ShardCoordinator.SelfId() != core.MetachainShardId { - tpn.ForkDetector, _ = processSync.NewShardForkDetector(tpn.RoundHandler, tpn.BlockBlackListHandler, tpn.BlockTracker, tpn.NodesSetup.GetStartTime()) - } else { - tpn.ForkDetector, _ = 
processSync.NewMetaForkDetector(tpn.RoundHandler, tpn.BlockBlackListHandler, tpn.BlockTracker, tpn.NodesSetup.GetStartTime()) - } - accountsDb := make(map[state.AccountsDbIdentifier]state.AccountsAdapter) accountsDb[state.UserAccountsState] = tpn.AccntState accountsDb[state.PeerAccountsState] = tpn.PeerState - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + coreComponents := GetDefaultCoreComponents(tpn.EnableEpochsHandler, tpn.EpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.HasherField = TestHasher coreComponents.Uint64ByteSliceConverterField = TestUint64Converter coreComponents.RoundHandlerField = tpn.RoundHandler - coreComponents.EnableEpochsHandlerField = tpn.EnableEpochsHandler - coreComponents.EpochNotifierField = tpn.EpochNotifier coreComponents.EconomicsDataField = tpn.EconomicsData coreComponents.RoundNotifierField = tpn.RoundNotifier @@ -2188,7 +2227,25 @@ func (tpn *TestProcessorNode) initBlockProcessor() { dataComponents.DataPool = tpn.DataPool dataComponents.BlockChain = tpn.BlockChain - bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator) + if tpn.ShardCoordinator.SelfId() != core.MetachainShardId { + tpn.ForkDetector, _ = processSync.NewShardForkDetector( + tpn.RoundHandler, + tpn.BlockBlackListHandler, + tpn.BlockTracker, + tpn.NodesSetup.GetStartTime(), + tpn.EnableEpochsHandler, + tpn.DataPool.Proofs()) + } else { + tpn.ForkDetector, _ = processSync.NewMetaForkDetector( + tpn.RoundHandler, + tpn.BlockBlackListHandler, + tpn.BlockTracker, + tpn.NodesSetup.GetStartTime(), + tpn.EnableEpochsHandler, + tpn.DataPool.Proofs()) + } + + bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator, tpn.EnableEpochsHandler) bootstrapComponents.HdrIntegrityVerifier = tpn.HeaderIntegrityVerifier statusComponents := GetDefaultStatusComponents() @@ -2473,8 +2530,13 @@ func (tpn *TestProcessorNode) initNode() { StatusMetricsField: tpn.StatusMetrics, AppStatusHandlerField: tpn.AppStatusHandler, } - - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + coreComponents := GetDefaultCoreComponents(tpn.EnableEpochsHandler, tpn.EpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.VmMarshalizerField = TestVmMarshalizer coreComponents.TxMarshalizerField = TestTxSignMarshalizer @@ -2504,7 +2566,7 @@ func (tpn *TestProcessorNode) initNode() { dataComponents.DataPool = tpn.DataPool dataComponents.Store = tpn.Storage - bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator) + bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator, tpn.EnableEpochsHandler) processComponents := GetDefaultProcessComponents() processComponents.BlockProcess = tpn.BlockProcessor @@ -2687,7 +2749,10 @@ func (tpn *TestProcessorNode) LoadTxSignSkBytes(skBytes []byte) { } // ProposeBlock proposes a new block -func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.BodyHandler, data.HeaderHandler, [][]byte) { +func (tpn 
*TestProcessorNode) ProposeBlock( + round uint64, + nonce uint64, +) (data.BodyHandler, data.HeaderHandler, [][]byte) { startTime := time.Now() maxTime := time.Second * 2 @@ -2699,6 +2764,7 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod blockHeader, err := tpn.BlockProcessor.CreateNewHeader(round, nonce) if err != nil { + log.Warn("blockHeader.CreateNewHeader", "error", err.Error()) return nil, nil, nil } @@ -2708,12 +2774,6 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod return nil, nil, nil } - err = blockHeader.SetPubKeysBitmap([]byte{1}) - if err != nil { - log.Warn("blockHeader.SetPubKeysBitmap", "error", err.Error()) - return nil, nil, nil - } - currHdr := tpn.BlockChain.GetCurrentBlockHeader() currHdrHash := tpn.BlockChain.GetCurrentBlockHeaderHash() if check.IfNil(currHdr) { @@ -2732,22 +2792,10 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod log.Warn("blockHeader.SetPrevRandSeed", "error", err.Error()) return nil, nil, nil } - sig := []byte("aggregated signature") - err = blockHeader.SetSignature(sig) - if err != nil { - log.Warn("blockHeader.SetSignature", "error", err.Error()) - return nil, nil, nil - } - - err = blockHeader.SetRandSeed(sig) - if err != nil { - log.Warn("blockHeader.SetRandSeed", "error", err.Error()) - return nil, nil, nil - } - err = blockHeader.SetLeaderSignature([]byte("leader sign")) + err = tpn.setBlockSignatures(blockHeader) if err != nil { - log.Warn("blockHeader.SetLeaderSignature", "error", err.Error()) + log.Warn("setBlockSignatures", "error", err.Error()) return nil, nil, nil } @@ -2796,6 +2844,40 @@ func (tpn *TestProcessorNode) ProposeBlock(round uint64, nonce uint64) (data.Bod return blockBody, blockHeader, txHashes } +func (tpn *TestProcessorNode) setBlockSignatures( + blockHeader data.HeaderHandler, +) error { + sig := []byte("aggregated signature") + pubKeysBitmap := []byte{1} + + err := blockHeader.SetRandSeed(sig) + if err != nil { + log.Warn("blockHeader.SetRandSeed", "error", err.Error()) + return err + } + + leaderSignature := []byte("leader signature") + err = blockHeader.SetLeaderSignature(leaderSignature) + if err != nil { + log.Warn("blockHeader.SetLeaderSignature", "error", err.Error()) + return err + } + + err = blockHeader.SetPubKeysBitmap(pubKeysBitmap) + if err != nil { + log.Warn("blockHeader.SetPubKeysBitmap", "error", err.Error()) + return err + } + + err = blockHeader.SetSignature(sig) + if err != nil { + log.Warn("blockHeader.SetSignature", "error", err.Error()) + return err + } + + return nil +} + // BroadcastBlock broadcasts the block and body to the connected peers func (tpn *TestProcessorNode) BroadcastBlock(body data.BodyHandler, header data.HeaderHandler, publicKey crypto.PublicKey) { _ = tpn.BroadcastMessenger.BroadcastBlock(body, header) @@ -2809,6 +2891,15 @@ func (tpn *TestProcessorNode) BroadcastBlock(body data.BodyHandler, header data. 
_ = tpn.BroadcastMessenger.BroadcastTransactions(transactions, pkBytes) } +// BroadcastProof broadcasts the proof to the connected peers +func (tpn *TestProcessorNode) BroadcastProof( + proof data.HeaderProofHandler, + publicKey crypto.PublicKey, +) { + pkBytes, _ := publicKey.ToByteArray() + _ = tpn.BroadcastMessenger.BroadcastEquivalentProof(proof, pkBytes) +} + // WhiteListBody will whitelist all miniblocks from the given body for all the given nodes func (tpn *TestProcessorNode) WhiteListBody(nodes []*TestProcessorNode, bodyHandler data.BodyHandler) { body, ok := bodyHandler.(*dataBlock.Body) @@ -3052,45 +3143,56 @@ func (tpn *TestProcessorNode) initRequestedItemsHandler() { func (tpn *TestProcessorNode) initBlockTracker() { argBaseTracker := track.ArgBaseTracker{ - Hasher: TestHasher, - HeaderValidator: tpn.HeaderValidator, - Marshalizer: TestMarshalizer, - RequestHandler: tpn.RequestHandler, - RoundHandler: tpn.RoundHandler, - ShardCoordinator: tpn.ShardCoordinator, - Store: tpn.Storage, - StartHeaders: tpn.GenesisBlocks, - PoolsHolder: tpn.DataPool, - WhitelistHandler: tpn.WhiteListHandler, - FeeHandler: tpn.EconomicsData, + Hasher: TestHasher, + HeaderValidator: tpn.HeaderValidator, + Marshalizer: TestMarshalizer, + RequestHandler: tpn.RequestHandler, + RoundHandler: tpn.RoundHandler, + ShardCoordinator: tpn.ShardCoordinator, + Store: tpn.Storage, + StartHeaders: tpn.GenesisBlocks, + PoolsHolder: tpn.DataPool, + WhitelistHandler: tpn.WhiteListHandler, + FeeHandler: tpn.EconomicsData, + EnableEpochsHandler: tpn.EnableEpochsHandler, + EpochChangeGracePeriodHandler: TestEpochChangeGracePeriod, + ProofsPool: tpn.DataPool.Proofs(), } + var err error if tpn.ShardCoordinator.SelfId() != core.MetachainShardId { arguments := track.ArgShardTracker{ ArgBaseTracker: argBaseTracker, } - tpn.BlockTracker, _ = track.NewShardBlockTrack(arguments) + tpn.BlockTracker, err = track.NewShardBlockTrack(arguments) + if err != nil { + panic(err.Error()) + } } else { arguments := track.ArgMetaTracker{ ArgBaseTracker: argBaseTracker, } - tpn.BlockTracker, _ = track.NewMetaBlockTrack(arguments) + tpn.BlockTracker, err = track.NewMetaBlockTrack(arguments) + if err != nil { + panic(err.Error()) + } } } func (tpn *TestProcessorNode) initHeaderValidator() { argsHeaderValidator := block.ArgsHeaderValidator{ - Hasher: TestHasher, - Marshalizer: TestMarshalizer, + Hasher: TestHasher, + Marshalizer: TestMarshalizer, + EnableEpochsHandler: tpn.EnableEpochsHandler, } tpn.HeaderValidator, _ = block.NewHeaderValidator(argsHeaderValidator) } func (tpn *TestProcessorNode) createHeartbeatWithHardforkTrigger() { - cacher := testscommon.NewCacherMock() + cacher := cacheMocks.NewCacherMock() psh, err := peerSignatureHandler.NewPeerSignatureHandler( cacher, tpn.OwnAccount.BlockSingleSigner, @@ -3281,14 +3383,12 @@ func CreateEnableEpochsConfig() config.EnableEpochs { SCProcessorV2EnableEpoch: UnreachableEpoch, FixRelayedBaseCostEnableEpoch: UnreachableEpoch, FixRelayedMoveBalanceToNonPayableSCEnableEpoch: UnreachableEpoch, + AndromedaEnableEpoch: UnreachableEpoch, } } // GetDefaultCoreComponents - -func GetDefaultCoreComponents(enableEpochsConfig config.EnableEpochs) *mock.CoreComponentsStub { - genericEpochNotifier := forking.NewGenericEpochNotifier() - enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(enableEpochsConfig, genericEpochNotifier) - +func GetDefaultCoreComponents(enableEpochsHandler common.EnableEpochsHandler, epochNotifier process.EpochNotifier) *mock.CoreComponentsStub { return &mock.CoreComponentsStub{ 
InternalMarshalizerField: TestMarshalizer, TxMarshalizerField: TestTxSignMarshalizer, @@ -3305,21 +3405,23 @@ func GetDefaultCoreComponents(enableEpochsConfig config.EnableEpochs) *mock.Core MinTransactionVersionCalled: func() uint32 { return 1 }, - StatusHandlerField: &statusHandlerMock.AppStatusHandlerStub{}, - WatchdogField: &testscommon.WatchdogMock{}, - AlarmSchedulerField: &testscommon.AlarmSchedulerStub{}, - SyncTimerField: &testscommon.SyncTimerStub{}, - RoundHandlerField: &testscommon.RoundHandlerMock{}, - EconomicsDataField: &economicsmocks.EconomicsHandlerMock{}, - RatingsDataField: &testscommon.RatingsInfoMock{}, - RaterField: &testscommon.RaterMock{}, - GenesisNodesSetupField: &genesisMocks.NodesSetupStub{}, - GenesisTimeField: time.Time{}, - EpochNotifierField: genericEpochNotifier, - EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, - TxVersionCheckField: versioning.NewTxVersionChecker(MinTransactionVersion), - ProcessStatusHandlerInternal: &testscommon.ProcessStatusHandlerStub{}, - EnableEpochsHandlerField: enableEpochsHandler, + StatusHandlerField: &statusHandlerMock.AppStatusHandlerStub{}, + WatchdogField: &testscommon.WatchdogMock{}, + AlarmSchedulerField: &testscommon.AlarmSchedulerStub{}, + SyncTimerField: &testscommon.SyncTimerStub{}, + RoundHandlerField: &testscommon.RoundHandlerMock{}, + EconomicsDataField: &economicsmocks.EconomicsHandlerMock{}, + RatingsDataField: &testscommon.RatingsInfoMock{}, + RaterField: &testscommon.RaterMock{}, + GenesisNodesSetupField: &genesisMocks.NodesSetupStub{}, + GenesisTimeField: time.Time{}, + EpochNotifierField: epochNotifier, + EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, + TxVersionCheckField: versioning.NewTxVersionChecker(MinTransactionVersion), + ProcessStatusHandlerInternal: &testscommon.ProcessStatusHandlerStub{}, + EnableEpochsHandlerField: enableEpochsHandler, + EpochChangeGracePeriodHandlerField: TestEpochChangeGracePeriod, + FieldsSizeCheckerField: &testscommon.FieldsSizeCheckerMock{}, } } @@ -3340,7 +3442,7 @@ func GetDefaultProcessComponents() *mock.ProcessComponentsStub { BlockProcess: &mock.BlockProcessorMock{}, BlackListHdl: &testscommon.TimeCacheStub{}, BootSore: &mock.BoostrapStorerMock{}, - HeaderSigVerif: &mock.HeaderSigVerifierStub{}, + HeaderSigVerif: &consensusMocks.HeaderSigVerifierMock{}, HeaderIntegrVerif: &mock.HeaderIntegrityVerifierStub{}, ValidatorStatistics: &testscommon.ValidatorStatisticsProcessorStub{}, ValidatorProvider: &stakingcommon.ValidatorsProviderStub{}, @@ -3438,10 +3540,17 @@ func GetDefaultStatusComponents() *mock.StatusComponentsStub { } // getDefaultBootstrapComponents - -func getDefaultBootstrapComponents(shardCoordinator sharding.Coordinator) *mainFactoryMocks.BootstrapComponentsStub { +func getDefaultBootstrapComponents(shardCoordinator sharding.Coordinator, handler common.EnableEpochsHandler) *mainFactoryMocks.BootstrapComponentsStub { var versionedHeaderFactory nodeFactory.VersionedHeaderFactory - headerVersionHandler := &testscommon.HeaderVersionHandlerStub{} + headerVersionHandler := &testscommon.HeaderVersionHandlerStub{ + GetVersionCalled: func(epoch uint32) string { + if handler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + return "2" + } + return "1" + }, + } versionedHeaderFactory, _ = hdrFactory.NewShardHeaderFactory(headerVersionHandler) if shardCoordinator.SelfId() == core.MetachainShardId { versionedHeaderFactory, _ = hdrFactory.NewMetaHeaderFactory(headerVersionHandler) @@ -3554,9 +3663,9 @@ func 
getDefaultNodesSetup(maxShards, numNodes uint32, address []byte, pksBytes m func getDefaultNodesCoordinator(maxShards uint32, pksBytes map[uint32][]byte) nodesCoordinator.NodesCoordinator { return &shardingMocks.NodesCoordinatorStub{ - ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pksBytes[shardId], 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, GetAllValidatorsPublicKeysCalled: func() (map[uint32][][]byte, error) { keys := make(map[uint32][][]byte) @@ -3586,5 +3695,18 @@ func GetDefaultEnableEpochsConfig() *config.EnableEpochs { DynamicGasCostForDataTrieStorageLoadEnableEpoch: UnreachableEpoch, StakingV4Step1EnableEpoch: UnreachableEpoch, StakingV4Step2EnableEpoch: UnreachableEpoch, + StakingV4Step3EnableEpoch: UnreachableEpoch, + AndromedaEnableEpoch: UnreachableEpoch, + } +} + +// GetDefaultRoundsConfig - +func GetDefaultRoundsConfig() config.RoundConfig { + return config.RoundConfig{ + RoundActivations: map[string]config.ActivationRoundByName{ + "DisableAsyncCallV1": { + Round: "18446744073709551615", + }, + }, } } diff --git a/integrationTests/testProcessorNodeWithCoordinator.go b/integrationTests/testProcessorNodeWithCoordinator.go index 63392658a76..de1171a512f 100644 --- a/integrationTests/testProcessorNodeWithCoordinator.go +++ b/integrationTests/testProcessorNodeWithCoordinator.go @@ -9,10 +9,12 @@ import ( "github.com/multiversx/mx-chain-crypto-go/signing" "github.com/multiversx/mx-chain-crypto-go/signing/ed25519" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests/mock" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/storage/cache" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" @@ -61,8 +63,14 @@ func CreateProcessorNodesWithNodesCoordinator( for i, v := range validatorList { lruCache, _ := cache.NewLRUCache(10000) argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: shardConsensusGroupSize, - MetaConsensusGroupSize: metaConsensusGroupSize, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(shardConsensusGroupSize), + MetachainConsensusGroupSize: uint32(metaConsensusGroupSize), + }, nil + }, + }, Marshalizer: TestMarshalizer, Hasher: TestHasher, ShardIDAsObserver: shardId, diff --git a/integrationTests/testProcessorNodeWithMultisigner.go b/integrationTests/testProcessorNodeWithMultisigner.go index 42f08a62b39..9b2150d0f8c 100644 --- 
a/integrationTests/testProcessorNodeWithMultisigner.go +++ b/integrationTests/testProcessorNodeWithMultisigner.go @@ -17,6 +17,7 @@ import ( crypto "github.com/multiversx/mx-chain-crypto-go" mclmultisig "github.com/multiversx/mx-chain-crypto-go/signing/mcl/multisig" "github.com/multiversx/mx-chain-crypto-go/signing/multisig" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/epochStart/notifier" "github.com/multiversx/mx-chain-go/factory/peerSignatureHandler" @@ -29,8 +30,11 @@ import ( "github.com/multiversx/mx-chain-go/storage/cache" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" @@ -180,6 +184,7 @@ func CreateNodeWithBLSAndTxKeys( ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, RefactorPeersMiniBlocksEnableEpoch: UnreachableEpoch, + AndromedaEnableEpoch: UnreachableEpoch, } return CreateNode( @@ -241,6 +246,7 @@ func CreateNodesWithNodesCoordinatorFactory( StakingV4Step1EnableEpoch: UnreachableEpoch, StakingV4Step2EnableEpoch: UnreachableEpoch, StakingV4Step3EnableEpoch: UnreachableEpoch, + AndromedaEnableEpoch: UnreachableEpoch, } nodesMap := make(map[uint32][]*TestProcessorNode) @@ -399,10 +405,6 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( nodesMap := make(map[uint32][]*TestProcessorNode) shufflerArgs := &nodesCoordinator.NodesShufflerArgs{ - NodesShard: uint32(nodesPerShard), - NodesMeta: uint32(nbMetaNodes), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, @@ -423,8 +425,18 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( for shardId, validatorList := range validatorsMap { consensusCache, _ := cache.NewLRUCache(10000) argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: shardConsensusGroupSize, - MetaConsensusGroupSize: metaConsensusGroupSize, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(shardConsensusGroupSize), + ShardMinNumNodes: uint32(nodesPerShard), + MetachainConsensusGroupSize: uint32(metaConsensusGroupSize), + MetachainMinNumNodes: uint32(nbMetaNodes), + Hysteresis: hysteresis, + Adaptivity: adaptivity, + }, nil + }, + }, Marshalizer: TestMarshalizer, Hasher: TestHasher, Shuffler: nodeShuffler, @@ -460,6 +472,10 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( SingleSigVerifier: signer, KeyGen: keyGen, 
FallbackHeaderValidator: &testscommon.FallBackHeaderValidatorStub{}, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + HeadersPool: &mock.HeadersCacherStub{}, + ProofsPool: &dataRetriever.ProofsPoolMock{}, + StorageService: &genericMocks.ChainStorerMock{}, } headerSig, _ := headerCheck.NewHeaderSigVerifier(&args) @@ -483,6 +499,7 @@ func CreateNodesWithNodesCoordinatorAndHeaderSigVerifier( StakingV2EnableEpoch: UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, + AndromedaEnableEpoch: UnreachableEpoch, }, NodeKeys: cp.NodesKeys[shardId][i], NodesSetup: nodesSetup, @@ -544,8 +561,14 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( bootStorer := CreateMemUnit() lruCache, _ := cache.NewLRUCache(10000) argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: shardConsensusGroupSize, - MetaConsensusGroupSize: metaConsensusGroupSize, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(shardConsensusGroupSize), + MetachainConsensusGroupSize: uint32(metaConsensusGroupSize), + }, nil + }, + }, Marshalizer: TestMarshalizer, Hasher: TestHasher, Shuffler: nodeShuffler, @@ -595,6 +618,10 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( SingleSigVerifier: singleSigner, KeyGen: keyGenForBlocks, FallbackHeaderValidator: &testscommon.FallBackHeaderValidatorStub{}, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + HeadersPool: &mock.HeadersCacherStub{}, + ProofsPool: &dataRetriever.ProofsPoolMock{}, + StorageService: &genericMocks.ChainStorerMock{}, } headerSig, _ := headerCheck.NewHeaderSigVerifier(&args) @@ -613,6 +640,7 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( StakingV2EnableEpoch: UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: UnreachableEpoch, + AndromedaEnableEpoch: UnreachableEpoch, }, NodeKeys: cp.NodesKeys[shardId][i], NodesSetup: nodesSetup, @@ -635,6 +663,15 @@ func CreateNodesWithNodesCoordinatorKeygenAndSingleSigner( return nodesMap } +// ProposeBlockData is a struct that holds some context data for the proposed block +type ProposeBlockData struct { + Body data.BodyHandler + Header data.HeaderHandler + Txs [][]byte + Leader *TestProcessorNode + ConsensusGroup []*TestProcessorNode +} + // ProposeBlockWithConsensusSignature proposes func ProposeBlockWithConsensusSignature( shardId uint32, @@ -643,39 +680,48 @@ func ProposeBlockWithConsensusSignature( nonce uint64, randomness []byte, epoch uint32, -) (data.BodyHandler, data.HeaderHandler, [][]byte, []*TestProcessorNode) { +) *ProposeBlockData { nodesCoordinatorInstance := nodesMap[shardId][0].NodesCoordinator - pubKeys, err := nodesCoordinatorInstance.GetConsensusValidatorsPublicKeys(randomness, round, shardId, epoch) + leaderPubKey, pubKeys, err := nodesCoordinatorInstance.GetConsensusValidatorsPublicKeys(randomness, round, shardId, epoch) if err != nil { log.Error("nodesCoordinator.GetConsensusValidatorsPublicKeys", "error", err) } // select nodes from map based on their pub keys - consensusNodes := selectTestNodesForPubKeys(nodesMap[shardId], pubKeys) + leaderNode, consensusNodes := selectTestNodesForPubKeys(nodesMap[shardId], leaderPubKey, pubKeys) // first node is block proposer - 
body, header, txHashes := consensusNodes[0].ProposeBlock(round, nonce) + body, header, txHashes := leaderNode.ProposeBlock(round, nonce) err = header.SetPrevRandSeed(randomness) if err != nil { log.Error("header.SetPrevRandSeed", "error", err) } - header = DoConsensusSigningOnBlock(header, consensusNodes, pubKeys) + header = DoConsensusSigningOnBlock(header, leaderNode, consensusNodes, pubKeys) - return body, header, txHashes, consensusNodes + return &ProposeBlockData{ + Body: body, + Header: header, + Txs: txHashes, + Leader: leaderNode, + ConsensusGroup: consensusNodes, + } } -func selectTestNodesForPubKeys(nodes []*TestProcessorNode, pubKeys []string) []*TestProcessorNode { +func selectTestNodesForPubKeys(nodes []*TestProcessorNode, leaderPubKey string, pubKeys []string) (*TestProcessorNode, []*TestProcessorNode) { selectedNodes := make([]*TestProcessorNode, len(pubKeys)) cntNodes := 0 - + var leaderNode *TestProcessorNode for i, pk := range pubKeys { - for _, node := range nodes { + for j, node := range nodes { pubKeyBytes, _ := node.NodeKeys.MainKey.Pk.ToByteArray() if bytes.Equal(pubKeyBytes, []byte(pk)) { - selectedNodes[i] = node + selectedNodes[i] = nodes[j] cntNodes++ } + if string(pubKeyBytes) == leaderPubKey { + leaderNode = nodes[j] + } } } @@ -683,12 +729,13 @@ func selectTestNodesForPubKeys(nodes []*TestProcessorNode, pubKeys []string) []* fmt.Println("Error selecting nodes from public keys") } - return selectedNodes + return leaderNode, selectedNodes } -// DoConsensusSigningOnBlock simulates a consensus aggregated signature on the provided block +// DoConsensusSigningOnBlock simulates a consensus group aggregated signature on the provided block func DoConsensusSigningOnBlock( blockHeader data.HeaderHandler, + leaderNode *TestProcessorNode, consensusNodes []*TestProcessorNode, pubKeys []string, ) data.HeaderHandler { @@ -719,7 +766,7 @@ func DoConsensusSigningOnBlock( pubKeysBytes := make([][]byte, len(consensusNodes)) sigShares := make([][]byte, len(consensusNodes)) - msig := consensusNodes[0].MultiSigner + msig := leaderNode.MultiSigner for i := 0; i < len(consensusNodes); i++ { pubKeysBytes[i] = []byte(pubKeys[i]) @@ -746,20 +793,14 @@ return blockHeader } -// AllShardsProposeBlock simulates each shard selecting a consensus group and proposing/broadcasting/committing a block +// AllShardsProposeBlock simulates each shard selecting its consensus group and proposing/broadcasting/committing a block func AllShardsProposeBlock( round uint64, nonce uint64, nodesMap map[uint32][]*TestProcessorNode, -) ( - map[uint32]data.BodyHandler, - map[uint32]data.HeaderHandler, - map[uint32][]*TestProcessorNode, -) { + ) map[uint32]*ProposeBlockData { - body := make(map[uint32]data.BodyHandler) - header := make(map[uint32]data.HeaderHandler) - consensusNodes := make(map[uint32][]*TestProcessorNode) + proposalData := make(map[uint32]*ProposeBlockData) newRandomness := make(map[uint32][]byte) nodesList := make([]*TestProcessorNode, 0) @@ -777,34 +818,36 @@ // TODO: remove if start of epoch block needs to be validated by the new epoch nodes epoch := currentBlockHeader.GetEpoch() prevRandomness := currentBlockHeader.GetRandSeed() - body[i], header[i], _, consensusNodes[i] = ProposeBlockWithConsensusSignature( + proposalData[i] = ProposeBlockWithConsensusSignature( i, nodesMap, round, nonce, prevRandomness, epoch, ) - nodesMap[i][0].WhiteListBody(nodesList, body[i]) - newRandomness[i] = header[i].GetRandSeed() +
proposalData[i].Leader.WhiteListBody(nodesList, proposalData[i].Body) + newRandomness[i] = proposalData[i].Header.GetRandSeed() } // propagate blocks for i := range nodesMap { - pk := consensusNodes[i][0].NodeKeys.MainKey.Pk - consensusNodes[i][0].BroadcastBlock(body[i], header[i], pk) - consensusNodes[i][0].CommitBlock(body[i], header[i]) + leader := proposalData[i].Leader + pk := proposalData[i].Leader.NodeKeys.MainKey.Pk + leader.BroadcastBlock(proposalData[i].Body, proposalData[i].Header, pk) + leader.CommitBlock(proposalData[i].Body, proposalData[i].Header) } time.Sleep(2 * StepDelay) - return body, header, consensusNodes + return proposalData } // SyncAllShardsWithRoundBlock enforces all nodes in each shard synchronizing the block for the given round func SyncAllShardsWithRoundBlock( t *testing.T, + proposalData map[uint32]*ProposeBlockData, nodesMap map[uint32][]*TestProcessorNode, - indexProposers map[uint32]int, round uint64, ) { - for shard, nodeList := range nodesMap { - SyncBlock(t, nodeList, []int{indexProposers[shard]}, round) + for shard, nodesList := range nodesMap { + proposal := proposalData[shard] + SyncBlock(t, nodesList, []*TestProcessorNode{proposal.Leader}, round) } time.Sleep(4 * StepDelay) } diff --git a/integrationTests/testProcessorNodeWithTestWebServer.go b/integrationTests/testProcessorNodeWithTestWebServer.go index b4e2490b900..c2a3a97426f 100644 --- a/integrationTests/testProcessorNodeWithTestWebServer.go +++ b/integrationTests/testProcessorNodeWithTestWebServer.go @@ -7,6 +7,10 @@ import ( "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" + "github.com/multiversx/mx-chain-vm-common-go/parsers" + datafield "github.com/multiversx/mx-chain-vm-common-go/parsers/dataField" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + "github.com/multiversx/mx-chain-go/api/groups" "github.com/multiversx/mx-chain-go/api/shared" "github.com/multiversx/mx-chain-go/config" @@ -22,14 +26,12 @@ import ( "github.com/multiversx/mx-chain-go/process/transactionEvaluator" "github.com/multiversx/mx-chain-go/process/txstatus" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/state" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts/defaults" - "github.com/multiversx/mx-chain-vm-common-go/parsers" - datafield "github.com/multiversx/mx-chain-vm-common-go/parsers/dataField" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" ) // TestProcessorNodeWithTestWebServer represents a TestProcessorNode with a test web server @@ -178,7 +180,7 @@ func createFacadeComponents(tpn *TestProcessorNode) nodeFacade.ApiResolver { ShardCoordinator: tpn.ShardCoordinator, Marshalizer: TestMarshalizer, Hasher: TestHasher, - VMOutputCacher: &testscommon.CacherMock{}, + VMOutputCacher: &cache.CacherMock{}, DataFieldParser: dataFieldParser, BlockChainHook: tpn.BlockchainHook, } @@ -261,6 +263,8 @@ func createFacadeComponents(tpn *TestProcessorNode) 
nodeFacade.ApiResolver { AccountsRepository: &state.AccountsRepositoryStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ProofsPool: tpn.ProofsPool, + BlockChain: tpn.BlockChain, } blockAPIHandler, err := blockAPI.CreateAPIBlockProcessor(argsBlockAPI) log.LogIfError(err) diff --git a/integrationTests/testSyncNode.go b/integrationTests/testSyncNode.go index b28d5e3f953..385034d6304 100644 --- a/integrationTests/testSyncNode.go +++ b/integrationTests/testSyncNode.go @@ -5,7 +5,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" - "github.com/multiversx/mx-chain-go/common" + + "github.com/multiversx/mx-chain-go/common/enablers" + "github.com/multiversx/mx-chain-go/common/forking" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/provider" @@ -45,27 +47,25 @@ func (tpn *TestProcessorNode) initBlockProcessorWithSync() { accountsDb[state.UserAccountsState] = tpn.AccntState accountsDb[state.PeerAccountsState] = tpn.PeerState - coreComponents := GetDefaultCoreComponents(CreateEnableEpochsConfig()) + if tpn.EpochNotifier == nil { + tpn.EpochNotifier = forking.NewGenericEpochNotifier() + } + if tpn.EnableEpochsHandler == nil { + tpn.EnableEpochsHandler, _ = enablers.NewEnableEpochsHandler(CreateEnableEpochsConfig(), tpn.EpochNotifier) + } + coreComponents := GetDefaultCoreComponents(tpn.EnableEpochsHandler, tpn.EpochNotifier) coreComponents.InternalMarshalizerField = TestMarshalizer coreComponents.HasherField = TestHasher coreComponents.Uint64ByteSliceConverterField = TestUint64Converter coreComponents.EpochNotifierField = tpn.EpochNotifier coreComponents.RoundNotifierField = tpn.RoundNotifier - coreComponents.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ - GetActivationEpochCalled: func(flag core.EnableEpochFlag) uint32 { - if flag == common.RefactorPeersMiniBlocksFlag { - return UnreachableEpoch - } - return 0 - }, - } dataComponents := GetDefaultDataComponents() dataComponents.Store = tpn.Storage dataComponents.DataPool = tpn.DataPool dataComponents.BlockChain = tpn.BlockChain - bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator) + bootstrapComponents := getDefaultBootstrapComponents(tpn.ShardCoordinator, tpn.EnableEpochsHandler) bootstrapComponents.HdrIntegrityVerifier = tpn.HeaderIntegrityVerifier statusComponents := GetDefaultStatusComponents() @@ -108,7 +108,13 @@ func (tpn *TestProcessorNode) initBlockProcessorWithSync() { } if tpn.ShardCoordinator.SelfId() == core.MetachainShardId { - tpn.ForkDetector, _ = sync.NewMetaForkDetector(tpn.RoundHandler, tpn.BlockBlackListHandler, tpn.BlockTracker, 0) + tpn.ForkDetector, _ = sync.NewMetaForkDetector( + tpn.RoundHandler, + tpn.BlockBlackListHandler, + tpn.BlockTracker, + 0, + tpn.EnableEpochsHandler, + tpn.DataPool.Proofs()) argumentsBase.ForkDetector = tpn.ForkDetector argumentsBase.TxCoordinator = &mock.TransactionCoordinatorMock{} arguments := block.ArgMetaProcessor{ @@ -129,7 +135,13 @@ func (tpn *TestProcessorNode) initBlockProcessorWithSync() { tpn.BlockProcessor, err = block.NewMetaProcessor(arguments) } else { - tpn.ForkDetector, _ = sync.NewShardForkDetector(tpn.RoundHandler, tpn.BlockBlackListHandler, 
tpn.BlockTracker, 0) + tpn.ForkDetector, _ = sync.NewShardForkDetector( + tpn.RoundHandler, + tpn.BlockBlackListHandler, + tpn.BlockTracker, + 0, + tpn.EnableEpochsHandler, + tpn.DataPool.Proofs()) argumentsBase.ForkDetector = tpn.ForkDetector argumentsBase.BlockChainHook = tpn.BlockchainHook argumentsBase.TxCoordinator = tpn.TxCoordinator @@ -176,6 +188,7 @@ func (tpn *TestProcessorNode) createShardBootstrapper() (TestBootstrapper, error ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, ProcessWaitTime: tpn.RoundHandler.TimeDuration(), RepopulateTokensSupplies: false, + EnableEpochsHandler: tpn.EnableEpochsHandler, } argsShardBootstrapper := sync.ArgShardBootstrapper{ @@ -222,6 +235,7 @@ func (tpn *TestProcessorNode) createMetaChainBootstrapper() (TestBootstrapper, e ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, ProcessWaitTime: tpn.RoundHandler.TimeDuration(), RepopulateTokensSupplies: false, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } argsMetaBootstrapper := sync.ArgMetaBootstrapper{ diff --git a/integrationTests/vm/delegation/changeOwner_test.go b/integrationTests/vm/delegation/changeOwner_test.go index 9caae7742c5..11a716a019c 100644 --- a/integrationTests/vm/delegation/changeOwner_test.go +++ b/integrationTests/vm/delegation/changeOwner_test.go @@ -6,12 +6,13 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" - "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/state" - "github.com/multiversx/mx-chain-go/testscommon" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/testscommon" ) var ( @@ -55,7 +56,7 @@ func TestDelegationChangeOwnerOnAccountHandler(t *testing.T) { // verify the new owner is still the delegator verifyDelegatorsStake(t, tpn, "getUserActiveStake", [][]byte{newOwner}, userAccount.AddressBytes(), big.NewInt(2000)) - //get the SC delegation account + // get the SC delegation account account, err := tpn.AccntState.LoadAccount(scAddress) require.Nil(t, err) @@ -92,7 +93,7 @@ func testDelegationChangeOwnerOnAccountHandler(t *testing.T, epochToTest uint32) changeOwner(t, tpn, firstOwner, newOwner, delegationScAddress) verifyDelegatorsStake(t, tpn, "getUserActiveStake", [][]byte{newOwner}, delegationScAddress, big.NewInt(2000)) - //get the SC delegation account + // get the SC delegation account account, err := tpn.AccntState.LoadAccount(delegationScAddress) require.Nil(t, err) diff --git a/integrationTests/vm/delegation/delegation_test.go b/integrationTests/vm/delegation/delegation_test.go index 9bae5235076..3b766314ccc 100644 --- a/integrationTests/vm/delegation/delegation_test.go +++ b/integrationTests/vm/delegation/delegation_test.go @@ -7,16 +7,16 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" - 
"github.com/multiversx/mx-chain-go/integrationTests/multiShard/endOfEpoch" integrationTestsVm "github.com/multiversx/mx-chain-go/integrationTests/vm" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" "github.com/multiversx/mx-chain-go/vm" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestDelegationSystemSCWithValidatorStatisticsAndStakingPhase3p5(t *testing.T) { @@ -263,17 +263,14 @@ func processBlocks( blockToProduce uint64, nodesMap map[uint32][]*integrationTests.TestProcessorNode, ) (uint64, uint64) { - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode - for i := uint64(0); i < blockToProduce; i++ { for _, nodesSlice := range nodesMap { integrationTests.UpdateRound(nodesSlice, round) integrationTests.AddSelfNotarizedHeaderByMetachain(nodesSlice) } - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - indexesProposers := endOfEpoch.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ diff --git a/integrationTests/vm/esdt/common.go b/integrationTests/vm/esdt/common.go index 0d3a798d592..5aa943e551c 100644 --- a/integrationTests/vm/esdt/common.go +++ b/integrationTests/vm/esdt/common.go @@ -9,6 +9,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/esdt" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" testVm "github.com/multiversx/mx-chain-go/integrationTests/vm" @@ -19,8 +22,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" "github.com/multiversx/mx-chain-go/vm" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/require" ) // GetESDTTokenData - @@ -91,12 +92,12 @@ func SetRolesWithSenderAccount(nodes []*integrationTests.TestProcessorNode, issu func DeployNonPayableSmartContract( t *testing.T, nodes []*integrationTests.TestProcessorNode, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nonce *uint64, round *uint64, fileName string, ) []byte { - return DeployNonPayableSmartContractFromNode(t, nodes, 0, idxProposers, nonce, round, fileName) + return DeployNonPayableSmartContractFromNode(t, nodes, 0, leaders, nonce, round, fileName) } // DeployNonPayableSmartContractFromNode - @@ -104,7 +105,7 @@ func DeployNonPayableSmartContractFromNode( t *testing.T, nodes []*integrationTests.TestProcessorNode, idDeployer int, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nonce *uint64, round *uint64, fileName string, @@ -121,7 +122,7 @@ func DeployNonPayableSmartContractFromNode( 
integrationTests.AdditionalGasLimit, ) - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, 4, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, *nonce, *round) scShardID := nodes[0].ShardCoordinator.ComputeId(scAddress) for _, node := range nodes { @@ -165,11 +166,12 @@ func CheckAddressHasTokens( } // CreateNodesAndPrepareBalances - -func CreateNodesAndPrepareBalances(numOfShards int) ([]*integrationTests.TestProcessorNode, []int) { +func CreateNodesAndPrepareBalances(numOfShards int) ([]*integrationTests.TestProcessorNode, []*integrationTests.TestProcessorNode) { enableEpochs := config.EnableEpochs{ OptimizeGasUsedInCrossMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, } roundsConfig := testscommon.GetDefaultRoundsConfig() return CreateNodesAndPrepareBalancesWithEpochsAndRoundsConfig( @@ -180,7 +182,11 @@ func CreateNodesAndPrepareBalances(numOfShards int) ([]*integrationTests.TestPro } // CreateNodesAndPrepareBalancesWithEpochsAndRoundsConfig - -func CreateNodesAndPrepareBalancesWithEpochsAndRoundsConfig(numOfShards int, enableEpochs config.EnableEpochs, roundsConfig config.RoundConfig) ([]*integrationTests.TestProcessorNode, []int) { +func CreateNodesAndPrepareBalancesWithEpochsAndRoundsConfig( + numOfShards int, + enableEpochs config.EnableEpochs, + roundsConfig config.RoundConfig, +) ([]*integrationTests.TestProcessorNode, []*integrationTests.TestProcessorNode) { nodesPerShard := 1 numMetachainNodes := 1 @@ -198,14 +204,14 @@ func CreateNodesAndPrepareBalancesWithEpochsAndRoundsConfig(numOfShards int, ena }, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) - return nodes, idxProposers + return nodes, leaders } // IssueNFT - @@ -387,7 +393,7 @@ func PrepareFungibleTokensWithLocalBurnAndMint( t *testing.T, nodes []*integrationTests.TestProcessorNode, addressWithRoles []byte, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, round *uint64, nonce *uint64, ) string { @@ -396,7 +402,7 @@ func PrepareFungibleTokensWithLocalBurnAndMint( nodes, nodes[0].OwnAccount, addressWithRoles, - idxProposers, + leaders, round, nonce) } @@ -407,7 +413,7 @@ func PrepareFungibleTokensWithLocalBurnAndMintWithIssuerAccount( nodes []*integrationTests.TestProcessorNode, issuerAccount *integrationTests.TestWalletAccount, addressWithRoles []byte, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, round *uint64, nonce *uint64, ) string { @@ -415,7 +421,7 @@ func PrepareFungibleTokensWithLocalBurnAndMintWithIssuerAccount( time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("TKN"))) @@ -424,7 +430,7 @@ 
func PrepareFungibleTokensWithLocalBurnAndMintWithIssuerAccount( time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) return tokenIdentifier diff --git a/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go b/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go index 742531fb801..a33b882a58c 100644 --- a/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go +++ b/integrationTests/vm/esdt/localFuncs/esdtLocalFunsSC_test.go @@ -7,17 +7,18 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/integrationTests" esdtCommon "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" - "github.com/stretchr/testify/assert" ) func TestESDTLocalMintAndBurnFromSC(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdtCommon.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdtCommon.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -33,9 +34,9 @@ func TestESDTLocalMintAndBurnFromSC(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/local-esdt-and-nft.wasm") + scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/local-esdt-and-nft.wasm") - esdtLocalMintAndBurnFromSCRunTestsAndAsserts(t, nodes, nodes[0].OwnAccount, scAddress, idxProposers, nonce, round) + esdtLocalMintAndBurnFromSCRunTestsAndAsserts(t, nodes, nodes[0].OwnAccount, scAddress, leaders, nonce, round) } func esdtLocalMintAndBurnFromSCRunTestsAndAsserts( @@ -43,11 +44,11 @@ func esdtLocalMintAndBurnFromSCRunTestsAndAsserts( nodes []*integrationTests.TestProcessorNode, ownerWallet *integrationTests.TestWalletAccount, scAddress []byte, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nonce uint64, round uint64, ) { - tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMintWithIssuerAccount(t, nodes, ownerWallet, scAddress, idxProposers, &nonce, &round) + tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMintWithIssuerAccount(t, nodes, ownerWallet, scAddress, leaders, &nonce, &round) txData := []byte("localMint" + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + hex.EncodeToString(big.NewInt(100).Bytes())) @@ -72,7 +73,7 @@ func esdtLocalMintAndBurnFromSCRunTestsAndAsserts( time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 2 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, scAddress, nodes, []byte(tokenIdentifier), 0, 200) @@ -99,7 +100,7 @@ func esdtLocalMintAndBurnFromSCRunTestsAndAsserts( ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, 
round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, scAddress, nodes, []byte(tokenIdentifier), 0, 100) @@ -109,7 +110,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdtCommon.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdtCommon.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -125,7 +126,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/local-esdt-and-nft.wasm") + scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/local-esdt-and-nft.wasm") issuePrice := big.NewInt(1000) txData := []byte("issueFungibleToken" + "@" + hex.EncodeToString([]byte("TOKEN")) + @@ -141,7 +142,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("TKR"))) @@ -157,7 +158,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) txData = []byte("localMint" + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + @@ -180,7 +181,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, scAddress, nodes, []byte(tokenIdentifier), 0, 201) @@ -205,7 +206,7 @@ func TestESDTSetRolesAndLocalMintAndBurnFromSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, scAddress, nodes, []byte(tokenIdentifier), 0, 101) @@ -215,7 +216,7 @@ func TestESDTSetTransferRoles(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdtCommon.CreateNodesAndPrepareBalances(2) + nodes, leaders := esdtCommon.CreateNodesAndPrepareBalances(2) defer func() { for _, n := range nodes { @@ -231,14 +232,14 @@ func TestESDTSetTransferRoles(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/use-module.wasm") + scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, leaders, 
&nonce, &round, "../testdata/use-module.wasm") nrRoundsToPropagateMultiShard := 12 - tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMint(t, nodes, scAddress, idxProposers, &nonce, &round) + tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMint(t, nodes, scAddress, leaders, &nonce, &round) esdtCommon.SetRoles(nodes, scAddress, []byte(tokenIdentifier), [][]byte{[]byte(core.ESDTRoleTransfer)}) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) destAddress := nodes[1].OwnAccount.Address @@ -256,7 +257,7 @@ func TestESDTSetTransferRoles(t *testing.T) { integrationTests.AdditionalGasLimit+core.MinMetaTxExtraGasCost, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, destAddress, nodes, []byte(tokenIdentifier), 0, amount) @@ -279,7 +280,7 @@ func TestESDTSetTransferRolesForwardAsyncCallFailsCross(t *testing.T) { } func testESDTWithTransferRoleAndForwarder(t *testing.T, numShards int) { - nodes, idxProposers := esdtCommon.CreateNodesAndPrepareBalances(numShards) + nodes, leaders := esdtCommon.CreateNodesAndPrepareBalances(numShards) defer func() { for _, n := range nodes { @@ -295,15 +296,15 @@ func testESDTWithTransferRoleAndForwarder(t *testing.T, numShards int) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddressA := esdtCommon.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/use-module.wasm") - scAddressB := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 1, idxProposers, &nonce, &round, "../testdata/use-module.wasm") + scAddressA := esdtCommon.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/use-module.wasm") + scAddressB := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 1, leaders, &nonce, &round, "../testdata/use-module.wasm") nrRoundsToPropagateMultiShard := 12 - tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMint(t, nodes, scAddressA, idxProposers, &nonce, &round) + tokenIdentifier := esdtCommon.PrepareFungibleTokensWithLocalBurnAndMint(t, nodes, scAddressA, leaders, &nonce, &round) esdtCommon.SetRoles(nodes, scAddressA, []byte(tokenIdentifier), [][]byte{[]byte(core.ESDTRoleTransfer)}) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) amount := int64(100) @@ -319,7 +320,7 @@ func testESDTWithTransferRoleAndForwarder(t *testing.T, numShards int) { integrationTests.AdditionalGasLimit+core.MinMetaTxExtraGasCost, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 15, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 15, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, scAddressB, nodes, []byte(tokenIdentifier), 0, 0) @@ -344,7 +345,7 @@ func TestAsyncCallsAndCallBacksArgumentsCross(t *testing.T) { } func 
testAsyncCallAndCallBacksArguments(t *testing.T, numShards int) { - nodes, idxProposers := esdtCommon.CreateNodesAndPrepareBalances(numShards) + nodes, leaders := esdtCommon.CreateNodesAndPrepareBalances(numShards) defer func() { for _, n := range nodes { n.Close() @@ -359,8 +360,8 @@ func testAsyncCallAndCallBacksArguments(t *testing.T, numShards int) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddressA := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 0, idxProposers, &nonce, &round, "forwarder.wasm") - scAddressB := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 1, idxProposers, &nonce, &round, "vault.wasm") + scAddressA := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 0, leaders, &nonce, &round, "forwarder.wasm") + scAddressB := esdtCommon.DeployNonPayableSmartContractFromNode(t, nodes, 1, leaders, &nonce, &round, "vault.wasm") txData := txDataBuilder.NewBuilder() txData.Clear().Func("echo_args_async").Bytes(scAddressB).Str("AA").Str("BB") @@ -374,7 +375,7 @@ func testAsyncCallAndCallBacksArguments(t *testing.T, numShards int) { integrationTests.AdditionalGasLimit+core.MinMetaTxExtraGasCost, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 15, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 15, nonce, round) time.Sleep(time.Second) callbackArgs := append([]byte("success"), []byte{0}...) @@ -391,7 +392,7 @@ func testAsyncCallAndCallBacksArguments(t *testing.T, numShards int) { integrationTests.AdditionalGasLimit+core.MinMetaTxExtraGasCost, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 15, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 15, nonce, round) time.Sleep(time.Second) checkDataFromAccountAndKey(t, nodes, scAddressA, []byte("callbackStorage"), append([]byte("success"), []byte{0}...)) diff --git a/integrationTests/vm/esdt/multisign/esdtMultisign_test.go b/integrationTests/vm/esdt/multisign/esdtMultisign_test.go index fd8e0b6fbb8..8a82988663a 100644 --- a/integrationTests/vm/esdt/multisign/esdtMultisign_test.go +++ b/integrationTests/vm/esdt/multisign/esdtMultisign_test.go @@ -8,14 +8,15 @@ import ( "testing" "time" - "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" - "github.com/multiversx/mx-chain-go/process" - "github.com/multiversx/mx-chain-go/vm" logger "github.com/multiversx/mx-chain-logger-go" vmcommon "github.com/multiversx/mx-chain-vm-common-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/integrationTests" + "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/vm" ) var vmType = []byte{5, 0} @@ -37,11 +38,11 @@ func TestESDTTransferWithMultisig(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = 
nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -63,7 +64,7 @@ func TestESDTTransferWithMultisig(t *testing.T) { time.Sleep(time.Second) numRoundsToPropagateIntraShard := 2 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateIntraShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateIntraShard, nonce, round) time.Sleep(time.Second) // ----- issue ESDT token @@ -72,7 +73,7 @@ func TestESDTTransferWithMultisig(t *testing.T) { proposeIssueTokenAndTransferFunds(nodes, multisignContractAddress, initalSupply, 0, ticker) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateIntraShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateIntraShard, nonce, round) time.Sleep(time.Second) actionID := getActionID(t, nodes, multisignContractAddress) @@ -82,13 +83,13 @@ func TestESDTTransferWithMultisig(t *testing.T) { time.Sleep(time.Second) numRoundsToPropagateCrossShard := 10 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateCrossShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateCrossShard, nonce, round) time.Sleep(time.Second) performActionID(nodes, multisignContractAddress, actionID, 0) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateCrossShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateCrossShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := integrationTests.GetTokenIdentifier(nodes, []byte(ticker)) @@ -102,7 +103,7 @@ func TestESDTTransferWithMultisig(t *testing.T) { proposeTransferToken(nodes, multisignContractAddress, transferValue, 0, destinationAddress, tokenIdentifier) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateIntraShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateIntraShard, nonce, round) time.Sleep(time.Second) actionID = getActionID(t, nodes, multisignContractAddress) @@ -111,13 +112,13 @@ func TestESDTTransferWithMultisig(t *testing.T) { boardMembersSignActionID(nodes, multisignContractAddress, actionID, 1, 2) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateCrossShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateCrossShard, nonce, round) time.Sleep(time.Second) performActionID(nodes, multisignContractAddress, actionID, 0) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, numRoundsToPropagateCrossShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, numRoundsToPropagateCrossShard, nonce, round) time.Sleep(time.Second) expectedBalance := big.NewInt(0).Set(initalSupply) diff --git a/integrationTests/vm/esdt/nft/common.go b/integrationTests/vm/esdt/nft/common.go index 6df8dc7dd69..23cd837ba3a 100644 --- a/integrationTests/vm/esdt/nft/common.go +++ b/integrationTests/vm/esdt/nft/common.go @@ -8,9 +8,10 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + 
"github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" - "github.com/stretchr/testify/require" ) // NftArguments - @@ -70,7 +71,7 @@ func CheckNftData( func PrepareNFTWithRoles( t *testing.T, nodes []*integrationTests.TestProcessorNode, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nftCreator *integrationTests.TestProcessorNode, round *uint64, nonce *uint64, @@ -82,7 +83,7 @@ func PrepareNFTWithRoles( time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 10 - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("SFT"))) @@ -91,7 +92,7 @@ func PrepareNFTWithRoles( esdt.SetRoles(nodes, nftCreator.OwnAccount.Address, []byte(tokenIdentifier), roles) time.Sleep(time.Second) - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) nftMetaData := NftArguments{ @@ -105,7 +106,7 @@ func PrepareNFTWithRoles( CreateNFT([]byte(tokenIdentifier), nftCreator, nodes, &nftMetaData) time.Sleep(time.Second) - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, 3, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, *nonce, *round) time.Sleep(time.Second) CheckNftData( diff --git a/integrationTests/vm/esdt/nft/esdtNFT/esdtNft_test.go b/integrationTests/vm/esdt/nft/esdtNFT/esdtNft_test.go index a1db92372bd..c35e513b357 100644 --- a/integrationTests/vm/esdt/nft/esdtNFT/esdtNft_test.go +++ b/integrationTests/vm/esdt/nft/esdtNFT/esdtNft_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt/nft" @@ -29,11 +30,11 @@ func TestESDTNonFungibleTokenCreateAndBurn(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -59,7 +60,7 @@ func TestESDTNonFungibleTokenCreateAndBurn(t *testing.T) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodes[1], &round, &nonce, @@ -85,7 +86,7 @@ func TestESDTNonFungibleTokenCreateAndBurn(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) // the 
token data is removed from trie if the quantity is 0, so we should not find it @@ -116,11 +117,11 @@ func TestESDTSemiFungibleTokenCreateAddAndBurn(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -148,7 +149,7 @@ func TestESDTSemiFungibleTokenCreateAddAndBurn(t *testing.T) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodes[1], &round, &nonce, @@ -174,7 +175,7 @@ func TestESDTSemiFungibleTokenCreateAddAndBurn(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity += quantityToAdd @@ -190,7 +191,7 @@ func TestESDTSemiFungibleTokenCreateAddAndBurn(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CheckNftData( @@ -219,7 +220,7 @@ func TestESDTSemiFungibleTokenCreateAddAndBurn(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity -= quantityToBurn @@ -249,11 +250,11 @@ func TestESDTNonFungibleTokenTransferSelfShard(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -278,7 +279,7 @@ func TestESDTNonFungibleTokenTransferSelfShard(t *testing.T) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodes[1], &round, &nonce, @@ -315,7 +316,7 @@ func TestESDTNonFungibleTokenTransferSelfShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) // check that the new address owns the NFT @@ -357,11 +358,11 @@ func TestESDTSemiFungibleTokenTransferCrossShard(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * 
nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -398,7 +399,7 @@ func TestESDTSemiFungibleTokenTransferCrossShard(t *testing.T) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodeInDifferentShard, &round, &nonce, @@ -424,7 +425,7 @@ func TestESDTSemiFungibleTokenTransferCrossShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity += quantityToAdd @@ -440,7 +441,7 @@ func TestESDTSemiFungibleTokenTransferCrossShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CheckNftData( @@ -469,7 +470,7 @@ func TestESDTSemiFungibleTokenTransferCrossShard(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 11 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity = initialQuantity + quantityToAdd - quantityToTransfer @@ -510,11 +511,11 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -542,7 +543,7 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodes[0], &round, &nonce, @@ -568,7 +569,7 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity += quantityToAdd @@ -584,7 +585,7 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CheckNftData( @@ -613,7 +614,7 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test 
time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 15 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity = 0 // make sure that the ESDT SC address didn't receive the token @@ -640,7 +641,7 @@ func TestESDTSemiFungibleTokenTransferToSystemScAddressShouldReceiveBack(t *test } func testNFTSendCreateRole(t *testing.T, numOfShards int) { - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(numOfShards) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(numOfShards) defer func() { for _, n := range nodes { @@ -665,7 +666,7 @@ func testNFTSendCreateRole(t *testing.T, numOfShards int) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nftCreator, &round, &nonce, @@ -698,7 +699,7 @@ func testNFTSendCreateRole(t *testing.T, numOfShards int) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 20 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CreateNFT( @@ -710,7 +711,7 @@ func testNFTSendCreateRole(t *testing.T, numOfShards int) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 2 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CheckNftData( @@ -766,11 +767,11 @@ func testESDTSemiFungibleTokenTransferRole(t *testing.T, numOfShards int) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -808,7 +809,7 @@ func testESDTSemiFungibleTokenTransferRole(t *testing.T, numOfShards int) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodeInDifferentShard, &round, &nonce, @@ -834,7 +835,7 @@ func testESDTSemiFungibleTokenTransferRole(t *testing.T, numOfShards int) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity += quantityToAdd @@ -850,7 +851,7 @@ func testESDTSemiFungibleTokenTransferRole(t *testing.T, numOfShards int) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 5 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nft.CheckNftData( @@ -879,7 +880,7 @@ func testESDTSemiFungibleTokenTransferRole(t *testing.T, numOfShards int) { 
time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 11 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity = initialQuantity + quantityToAdd - quantityToTransfer @@ -920,11 +921,11 @@ func TestESDTSFTWithEnhancedTransferRole(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -954,7 +955,7 @@ func TestESDTSFTWithEnhancedTransferRole(t *testing.T) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, tokenIssuer, &round, &nonce, @@ -980,7 +981,7 @@ func TestESDTSFTWithEnhancedTransferRole(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 2 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity += quantityToAdd @@ -1013,7 +1014,7 @@ func TestESDTSFTWithEnhancedTransferRole(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity = initialQuantity + quantityToAdd - int64(len(nodes)-1)*quantityToTransfer @@ -1056,7 +1057,7 @@ func TestESDTSFTWithEnhancedTransferRole(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard = 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) nftMetaData.Quantity = initialQuantity + quantityToAdd @@ -1101,7 +1102,7 @@ func TestNFTTransferCreateAndSetRolesCrossShard(t *testing.T) { } func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(numOfShards) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(numOfShards) defer func() { for _, n := range nodes { @@ -1126,7 +1127,7 @@ func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { tokenIdentifier, nftMetaData := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nftCreator, &round, &nonce, @@ -1158,7 +1159,7 @@ func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 15, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 15, nonce, round) time.Sleep(time.Second) // stopNFTCreate @@ -1173,7 +1174,7 @@ func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { ) time.Sleep(time.Second) - nonce, round = 
integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) // setCreateRole @@ -1190,7 +1191,7 @@ func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 20, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 20, nonce, round) time.Sleep(time.Second) newNFTMetaData := nft.NftArguments{ @@ -1210,7 +1211,7 @@ func testNFTTransferCreateRoleAndStop(t *testing.T, numOfShards int) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) // we check that old data remains on NONCE 1 - as creation must return failure diff --git a/integrationTests/vm/esdt/nft/esdtNFTSCs/esdtNFTSCs_test.go b/integrationTests/vm/esdt/nft/esdtNFTSCs/esdtNFTSCs_test.go index 534c1c7435e..a1c3b524c9f 100644 --- a/integrationTests/vm/esdt/nft/esdtNFTSCs/esdtNFTSCs_test.go +++ b/integrationTests/vm/esdt/nft/esdtNFTSCs/esdtNFTSCs_test.go @@ -7,17 +7,18 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt" "github.com/multiversx/mx-chain-go/integrationTests/vm/esdt/nft" - "github.com/stretchr/testify/require" ) func TestESDTNFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -33,7 +34,7 @@ func TestESDTNFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, idxProposers, &nonce, &round, "nftIssue", "@03@05") + scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, leaders, &nonce, &round, "nftIssue", "@03@05") txData := []byte("nftCreate" + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + hex.EncodeToString(big.NewInt(1).Bytes()) + "@" + hex.EncodeToString([]byte("name")) + @@ -65,7 +66,7 @@ func TestESDTNFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 3, big.NewInt(1)) @@ -92,7 +93,7 @@ func TestESDTNFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 2, big.NewInt(1)) @@ -123,7 +124,7 @@ func TestESDTNFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { 
integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, destinationAddress, nodes, []byte(tokenIdentifier), 2, big.NewInt(1)) @@ -136,7 +137,7 @@ func TestESDTSemiFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -152,7 +153,7 @@ func TestESDTSemiFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, idxProposers, &nonce, &round, "sftIssue", "@03@04@05") + scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, leaders, &nonce, &round, "sftIssue", "@03@04@05") txData := []byte("nftCreate" + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + hex.EncodeToString(big.NewInt(1).Bytes()) + "@" + hex.EncodeToString([]byte("name")) + @@ -179,7 +180,7 @@ func TestESDTSemiFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(11)) @@ -204,7 +205,7 @@ func TestESDTSemiFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(9)) @@ -234,7 +235,7 @@ func TestESDTSemiFTIssueCreateBurnSendViaAsyncViaExecuteOnSC(t *testing.T) { integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, destinationAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(9)) @@ -245,7 +246,7 @@ func TestESDTTransferNFTBetweenContractsAcceptAndNotAcceptWithRevert(t *testing. if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -261,7 +262,7 @@ func TestESDTTransferNFTBetweenContractsAcceptAndNotAcceptWithRevert(t *testing. 
round = integrationTests.IncrementAndPrintRound(round) nonce++ - scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, idxProposers, &nonce, &round, "nftIssue", "@03@05") + scAddress, tokenIdentifier := deployAndIssueNFTSFTThroughSC(t, nodes, leaders, &nonce, &round, "nftIssue", "@03@05") txData := []byte("nftCreate" + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + hex.EncodeToString(big.NewInt(1).Bytes()) + "@" + hex.EncodeToString([]byte("name")) + @@ -285,13 +286,13 @@ func TestESDTTransferNFTBetweenContractsAcceptAndNotAcceptWithRevert(t *testing. ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 2, big.NewInt(1)) checkAddressHasNft(t, scAddress, scAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) - destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../../testdata/nft-receiver.wasm") + destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../../testdata/nft-receiver.wasm") txData = []byte("transferNftViaAsyncCall" + "@" + hex.EncodeToString(destinationSCAddress) + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + hex.EncodeToString(big.NewInt(1).Bytes()) + "@" + hex.EncodeToString(big.NewInt(1).Bytes()) + "@" + hex.EncodeToString([]byte("wrongFunctionToCall"))) @@ -316,7 +317,7 @@ func TestESDTTransferNFTBetweenContractsAcceptAndNotAcceptWithRevert(t *testing. integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(0)) @@ -348,7 +349,7 @@ func TestESDTTransferNFTBetweenContractsAcceptAndNotAcceptWithRevert(t *testing. 
integrationTests.AdditionalGasLimit, ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, scAddress, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) @@ -361,7 +362,7 @@ func TestESDTTransferNFTToSCIntraShard(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(1) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(1) defer func() { for _, n := range nodes { @@ -384,7 +385,7 @@ func TestESDTTransferNFTToSCIntraShard(t *testing.T) { tokenIdentifier, _ := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodes[0], &round, &nonce, @@ -395,7 +396,7 @@ func TestESDTTransferNFTToSCIntraShard(t *testing.T) { nonceArg := hex.EncodeToString(big.NewInt(0).SetUint64(1).Bytes()) quantityToTransfer := hex.EncodeToString(big.NewInt(1).Bytes()) - destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../../testdata/nft-receiver.wasm") + destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../../testdata/nft-receiver.wasm") txData := core.BuiltInFunctionESDTNFTTransfer + "@" + hex.EncodeToString([]byte(tokenIdentifier)) + "@" + nonceArg + "@" + quantityToTransfer + "@" + hex.EncodeToString(destinationSCAddress) + "@" + hex.EncodeToString([]byte("acceptAndReturnCallData")) integrationTests.CreateAndSendTransaction( @@ -408,7 +409,7 @@ func TestESDTTransferNFTToSCIntraShard(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 3, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 3, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, nodes[0].OwnAccount.Address, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) @@ -418,7 +419,7 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { if testing.Short() { t.Skip("this is not a short test") } - nodes, idxProposers := esdt.CreateNodesAndPrepareBalances(2) + nodes, leaders := esdt.CreateNodesAndPrepareBalances(2) defer func() { for _, n := range nodes { @@ -434,7 +435,7 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { round = integrationTests.IncrementAndPrintRound(round) nonce++ - destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../../testdata/nft-receiver.wasm") + destinationSCAddress := esdt.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../../testdata/nft-receiver.wasm") destinationSCShardID := nodes[0].ShardCoordinator.ComputeId(destinationSCAddress) @@ -454,7 +455,7 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { tokenIdentifier, _ := nft.PrepareNFTWithRoles( t, nodes, - idxProposers, + leaders, nodeFromOtherShard, &round, &nonce, @@ -478,7 +479,7 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, nodeFromOtherShard.OwnAccount.Address, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) @@ -495,7 +496,7 @@ func 
TestESDTTransferNFTToSCCrossShard(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, nodeFromOtherShard.OwnAccount.Address, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) @@ -512,7 +513,7 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { ) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) time.Sleep(time.Second) checkAddressHasNft(t, nodeFromOtherShard.OwnAccount.Address, destinationSCAddress, nodes, []byte(tokenIdentifier), 1, big.NewInt(1)) @@ -521,13 +522,13 @@ func TestESDTTransferNFTToSCCrossShard(t *testing.T) { func deployAndIssueNFTSFTThroughSC( t *testing.T, nodes []*integrationTests.TestProcessorNode, - idxProposers []int, + leaders []*integrationTests.TestProcessorNode, nonce *uint64, round *uint64, issueFunc string, rolesEncoded string, ) ([]byte, string) { - scAddress := esdt.DeployNonPayableSmartContract(t, nodes, idxProposers, nonce, round, "../../testdata/local-esdt-and-nft.wasm") + scAddress := esdt.DeployNonPayableSmartContract(t, nodes, leaders, nonce, round, "../../testdata/local-esdt-and-nft.wasm") issuePrice := big.NewInt(1000) txData := []byte(issueFunc + "@" + hex.EncodeToString([]byte("TOKEN")) + @@ -543,7 +544,7 @@ func deployAndIssueNFTSFTThroughSC( time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("TKR"))) @@ -559,7 +560,7 @@ func deployAndIssueNFTSFTThroughSC( ) time.Sleep(time.Second) - *nonce, *round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, *nonce, *round, idxProposers) + *nonce, *round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, *nonce, *round) time.Sleep(time.Second) return scAddress, tokenIdentifier diff --git a/integrationTests/vm/esdt/process/esdtProcess_test.go b/integrationTests/vm/esdt/process/esdtProcess_test.go index 4c4b900b3f3..0321da07203 100644 --- a/integrationTests/vm/esdt/process/esdtProcess_test.go +++ b/integrationTests/vm/esdt/process/esdtProcess_test.go @@ -13,6 +13,10 @@ import ( "github.com/multiversx/mx-chain-core-go/data/esdt" "github.com/multiversx/mx-chain-core-go/data/smartContractResult" vmData "github.com/multiversx/mx-chain-core-go/data/vm" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + vmcommonBuiltInFunctions "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" @@ -24,9 +28,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/txDataBuilder" 
"github.com/multiversx/mx-chain-go/vm" "github.com/multiversx/mx-chain-go/vm/systemSmartContracts" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - vmcommonBuiltInFunctions "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" - "github.com/stretchr/testify/require" ) func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { @@ -43,6 +44,7 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { OptimizeGasUsedInCrossMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( numOfShards, @@ -51,11 +53,11 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { enableEpochs, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -82,7 +84,7 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -106,7 +108,7 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), vm.ESDTSCAddress, txData.ToString(), core.MinMetaTxExtraGasCost) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) finalSupply := initialSupply + mintValue @@ -131,7 +133,7 @@ func TestESDTIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) { time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtFrozenData := esdtCommon.GetESDTTokenData(t, nodes[1].OwnAccount.Address, nodes, []byte(tokenIdentifier), 0) @@ -175,6 +177,7 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, MiniBlockPartialExecutionEnableEpoch: integrationTests.UnreachableEpoch, MultiClaimOnDelegationEnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -184,11 +187,11 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { enableEpochs, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, 
numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -219,7 +222,7 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -233,7 +236,7 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), node.OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit) } - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) finalSupply := initialSupply @@ -250,7 +253,7 @@ func TestESDTCallBurnOnANonBurnableToken(t *testing.T) { time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtSCAcc := esdtCommon.GetUserAccountWithAddress(t, vm.ESDTSCAddress, nodes) @@ -279,11 +282,11 @@ func TestESDTIssueAndSelfTransferShouldNotChangeBalance(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -308,7 +311,7 @@ func TestESDTIssueAndSelfTransferShouldNotChangeBalance(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -322,7 +325,7 @@ func TestESDTIssueAndSelfTransferShouldNotChangeBalance(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), nodes[0].OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, nodes[0].OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply) @@ -398,11 +401,11 @@ func TestScSendsEsdtToUserWithMessage(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + 
leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -428,7 +431,7 @@ func TestScSendsEsdtToUserWithMessage(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -448,7 +451,7 @@ func TestScSendsEsdtToUserWithMessage(t *testing.T) { integrationTests.AdditionalGasLimit, ) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round) _, err := nodes[0].AccntState.GetExistingAccount(vaultScAddress) require.Nil(t, err) @@ -461,7 +464,7 @@ func TestScSendsEsdtToUserWithMessage(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), vaultScAddress, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToSendToSc) @@ -473,7 +476,7 @@ func TestScSendsEsdtToUserWithMessage(t *testing.T) { integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), vaultScAddress, txData.ToString(), integrationTests.AdditionalGasLimit) time.Sleep(time.Second) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToSendToSc+valueToRequest) @@ -495,11 +498,11 @@ func TestESDTcallsSC(t *testing.T) { numMetachainNodes, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -526,7 +529,7 @@ func TestESDTcallsSC(t *testing.T) { time.Sleep(time.Second) nrRoundsToPropagateMultiShard := 12 - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker))) @@ -541,7 +544,7 @@ func TestESDTcallsSC(t *testing.T) { } time.Sleep(time.Second) - nonce, 
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	numNodesWithoutIssuer := int64(len(nodes) - 1)
@@ -567,7 +570,7 @@ func TestESDTcallsSC(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err := nodes[0].AccntState.GetExistingAccount(scAddress)
 	require.Nil(t, err)
@@ -579,7 +582,7 @@ func TestESDTcallsSC(t *testing.T) {
 	}
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	scQuery1 := &process.SCQuery{
@@ -613,11 +616,11 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) {
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -643,7 +646,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 12
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker)))
@@ -663,7 +666,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err := nodes[0].AccntState.GetExistingAccount(vault)
 	require.Nil(t, err)
@@ -679,7 +682,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(forwarder)
 	require.Nil(t, err)
@@ -692,7 +695,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) {
 	integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIssuerBalance := initialSupply - valueToSendToSc
@@ -711,7 +714,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) {
 	integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	time.Sleep(time.Second)
 	esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, tokenIssuerBalance)
@@ -735,7 +738,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) {
 	integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit)
 	time.Sleep(5 * time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	time.Sleep(5 * time.Second)
 	tokenIssuerBalance -= valueToTransferWithExecSc
@@ -750,7 +753,7 @@ func TestScCallsScWithEsdtIntraShard(t *testing.T) {
 	integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit)
 	time.Sleep(5 * time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	time.Sleep(5 * time.Second)
 	tokenIssuerBalance -= valueToTransferWithExecSc
@@ -774,11 +777,11 @@ func TestCallbackPaymentEgld(t *testing.T) {
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -804,7 +807,7 @@ func TestCallbackPaymentEgld(t *testing.T) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 12
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker)))
@@ -824,7 +827,7 @@ func TestCallbackPaymentEgld(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err := nodes[0].AccntState.GetExistingAccount(secondScAddress)
 	require.Nil(t, err)
@@ -840,7 +843,7 @@ func TestCallbackPaymentEgld(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(forwarder)
 	require.Nil(t, err)
@@ -851,7 +854,7 @@ func TestCallbackPaymentEgld(t *testing.T) {
 	integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(valueToSendToSc), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round)
 	time.Sleep(time.Second)
 	esdtCommon.CheckNumCallBacks(t, forwarder, nodes, 1)
@@ -864,7 +867,7 @@ func TestCallbackPaymentEgld(t *testing.T) {
 	integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), forwarder, txData.ToString(), integrationTests.AdditionalGasLimit)
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round)
 	time.Sleep(time.Second)
 	esdtCommon.CheckNumCallBacks(t, forwarder, nodes, 2)
@@ -893,11 +896,11 @@ func TestScCallsScWithEsdtIntraShard_SecondScRefusesPayment(t *testing.T) {
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -923,7 +926,7 @@ func TestScCallsScWithEsdtIntraShard_SecondScRefusesPayment(t *testing.T) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 12
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker)))
@@ -944,7 +947,7 @@ func TestScCallsScWithEsdtIntraShard_SecondScRefusesPayment(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round)
 	_, err := nodes[0].AccntState.GetExistingAccount(secondScAddress)
 	require.Nil(t, err)
@@ -962,12 +965,12 @@ func TestScCallsScWithEsdtIntraShard_SecondScRefusesPayment(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(firstScAddress)
 	require.Nil(t, err)
-	nonce, round = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, idxProposers, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejected", 2)
-	_, _ = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, idxProposers, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejectedWithTransferAndExecute", 2)
+	nonce, round = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, leaders, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejected", 2)
+	_, _ = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, leaders, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejectedWithTransferAndExecute", 2)
 }
 func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) {
@@ -985,11 +988,11 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) {
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -1015,7 +1018,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 15
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker)))
@@ -1035,7 +1038,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round)
 	_, err := nodes[0].AccntState.GetExistingAccount(callerScAddress)
 	require.Nil(t, err)
@@ -1052,7 +1055,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(receiverScAddress)
 	require.Nil(t, err)
@@ -1073,7 +1076,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(callerScAddress)
 	require.Nil(t, err)
@@ -1101,7 +1104,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) {
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	// call caller sc with ESDTTransfer which will call the second sc with execute_on_dest_context
@@ -1122,7 +1125,7 @@ func TestScACallsScBWithExecOnDestESDT_TxPending(t *testing.T) {
 	)
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToTransfer)
@@ -1151,11 +1154,11 @@ func TestScACallsScBWithExecOnDestScAPerformsAsyncCall_NoCallbackInScB(t *testin
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -1194,7 +1197,7 @@ func TestScACallsScBWithExecOnDestScAPerformsAsyncCall_NoCallbackInScB(t *testin
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(callerScAddress)
 	require.Nil(t, err)
@@ -1214,7 +1217,7 @@ func TestScACallsScBWithExecOnDestScAPerformsAsyncCall_NoCallbackInScB(t *testin
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round)
 	time.Sleep(time.Second)
 	// issue ESDT by calling exec on dest context on child contract
@@ -1238,7 +1241,7 @@ func TestScACallsScBWithExecOnDestScAPerformsAsyncCall_NoCallbackInScB(t *testin
 	nrRoundsToPropagateMultiShard := 12
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenID := integrationTests.GetTokenIdentifier(nodes, []byte(ticker))
@@ -1285,6 +1288,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 		GlobalMintBurnDisableEpoch:              integrationTests.UnreachableEpoch,
 		SCProcessorV2EnableEpoch:                integrationTests.UnreachableEpoch,
 		FailExecutionOnEveryAPIErrorEnableEpoch: integrationTests.UnreachableEpoch,
+		AndromedaEnableEpoch:                    integrationTests.UnreachableEpoch,
 	}
 	arwenVersion := config.WasmVMVersionByEpoch{Version: "v1.4"}
 	vmConfig := &config.VirtualMachineConfig{
@@ -1299,11 +1303,11 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 		vmConfig,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -1328,7 +1332,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 15
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker)))
@@ -1348,7 +1352,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err := nodes[0].AccntState.GetExistingAccount(mapperScAddress)
 	require.Nil(t, err)
@@ -1365,7 +1369,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(senderScAddress)
 	require.Nil(t, err)
@@ -1381,7 +1385,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(senderScAddress)
 	require.Nil(t, err)
@@ -1400,7 +1404,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(receiverScAddress)
 	require.Nil(t, err)
@@ -1415,7 +1419,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	time.Sleep(time.Second)
 	issueCost := big.NewInt(1000)
@@ -1430,7 +1434,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 	)
 	nrRoundsToPropagateMultiShard = 25
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	scQuery := nodes[0].SCQueryService
@@ -1457,7 +1461,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 		integrationTests.AdditionalGasLimit,
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	valueToTransfer := int64(1000)
@@ -1475,7 +1479,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithIntermediaryExecOnDest_NotEn
 		integrationTests.AdditionalGasLimit,
 	)
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	time.Sleep(time.Second)
 	esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToTransfer)
@@ -1501,11 +1505,11 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -1534,7 +1538,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 15
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker)))
@@ -1543,7 +1547,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	esdtCommon.IssueTestToken(nodes, initialSupplyWEGLD, tickerWEGLD)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifierWEGLD := string(integrationTests.GetTokenIdentifier(nodes, []byte(tickerWEGLD)))
@@ -1563,7 +1567,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err := nodes[0].AccntState.GetExistingAccount(mapperScAddress)
 	require.Nil(t, err)
@@ -1580,7 +1584,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(senderScAddress)
 	require.Nil(t, err)
@@ -1596,7 +1600,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(senderScAddress)
 	require.Nil(t, err)
@@ -1615,7 +1619,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(receiverScAddress)
 	require.Nil(t, err)
@@ -1634,12 +1638,12 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round)
 	_, err = nodes[0].AccntState.GetExistingAccount(receiverScAddressWEGLD)
 	require.Nil(t, err)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	time.Sleep(time.Second)
 	issueCost := big.NewInt(1000)
@@ -1654,7 +1658,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	)
 	nrRoundsToPropagateMultiShard = 100
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	txData.Clear().Func("issue").Str(ticker).Str(tokenIdentifier).Str("B")
@@ -1668,7 +1672,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	)
 	nrRoundsToPropagateMultiShard = 100
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	txData.Clear().Func("issue").Str(tickerWEGLD).Str(tokenIdentifierWEGLD).Str("L")
@@ -1682,7 +1686,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	)
 	nrRoundsToPropagateMultiShard = 25
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	txData.Clear().Func("issue").Str(tickerWEGLD).Str(tokenIdentifierWEGLD).Str("B")
@@ -1696,7 +1700,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	)
 	nrRoundsToPropagateMultiShard = 25
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	txData.Clear().Func("setTicker").Str(tokenIdentifier).Str(string(receiverScAddress))
@@ -1710,7 +1714,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 400, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 400, nonce, round)
 	time.Sleep(time.Second)
 	txData.Clear().Func("setTicker").Str(tokenIdentifierWEGLD).Str(string(receiverScAddressWEGLD))
@@ -1761,7 +1765,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 		integrationTests.AdditionalGasLimit,
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	txData.Clear().Func("setBorrowTokenRoles").Int(3).Int(4).Int(5)
@@ -1813,7 +1817,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 		integrationTests.AdditionalGasLimit,
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	txData.Clear().Func("setBorrowTokenRoles").Int(3).Int(4).Int(5)
@@ -1828,7 +1832,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 	//
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	valueToTransfer := int64(1000)
@@ -1846,7 +1850,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 		integrationTests.AdditionalGasLimit,
 	)
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, 40, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 40, nonce, round)
 	time.Sleep(time.Second)
 	valueToTransferWEGLD := int64(1000)
@@ -1865,7 +1869,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 		integrationTests.AdditionalGasLimit,
 	)
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, 40, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 40, nonce, round)
 	time.Sleep(time.Second)
 	esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToTransfer)
@@ -1883,7 +1887,7 @@ func TestExecOnDestWithTokenTransferFromScAtoScBWithScCall_GasUsedMismatch(t *te
 		integrationTests.AdditionalGasLimit,
 	)
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, 25, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 25, nonce, round)
 	time.Sleep(time.Second)
 	esdtBorrowBUSDData := esdtCommon.GetESDTTokenData(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdStrBorrow), 0)
@@ -1906,11 +1910,11 @@ func TestIssueESDT_FromSCWithNotEnoughGas(t *testing.T) {
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -1938,7 +1942,7 @@ func TestIssueESDT_FromSCWithNotEnoughGas(t *testing.T) {
 	round = integrationTests.IncrementAndPrintRound(round)
 	nonce++
-	scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/local-esdt-and-nft.wasm")
+	scAddress := esdtCommon.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/local-esdt-and-nft.wasm")
 	alice := nodes[0]
 	issuePrice := big.NewInt(1000)
@@ -1954,14 +1958,14 @@ func TestIssueESDT_FromSCWithNotEnoughGas(t *testing.T) {
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round)
 	time.Sleep(time.Second)
 	userAccount := esdtCommon.GetUserAccountWithAddress(t, alice.OwnAccount.Address, nodes)
 	balanceAfterTransfer := userAccount.GetBalance()
 	nrRoundsToPropagateMultiShard := 15
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	userAccount = esdtCommon.GetUserAccountWithAddress(t, alice.OwnAccount.Address, nodes)
 	require.Equal(t, userAccount.GetBalance(), big.NewInt(0).Add(balanceAfterTransfer, issuePrice))
@@ -1982,6 +1986,7 @@ func TestIssueAndBurnESDT_MaxGasPerBlockExceeded(t *testing.T) {
 	enableEpochs := config.EnableEpochs{
 		GlobalMintBurnDisableEpoch:           integrationTests.UnreachableEpoch,
 		MaxBlockchainHookCountersEnableEpoch: integrationTests.UnreachableEpoch,
+		AndromedaEnableEpoch:                 integrationTests.UnreachableEpoch,
 	}
 	nodes := integrationTests.CreateNodesWithEnableEpochs(
 		numOfShards,
@@ -1990,11 +1995,11 @@ func TestIssueAndBurnESDT_MaxGasPerBlockExceeded(t *testing.T) {
 		enableEpochs,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -2032,7 +2037,7 @@ func TestIssueAndBurnESDT_MaxGasPerBlockExceeded(t *testing.T) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 12
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker)))
@@ -2065,7 +2070,7 @@ func TestIssueAndBurnESDT_MaxGasPerBlockExceeded(t *testing.T) {
 	}
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, 25, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 25, nonce, round)
 	time.Sleep(time.Second)
 	esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-int64(numBurns))
@@ -2106,11 +2111,11 @@ func TestScCallsScWithEsdtCrossShard_SecondScRefusesPayment(t *testing.T) {
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -2137,7 +2142,7 @@ func TestScCallsScWithEsdtCrossShard_SecondScRefusesPayment(t *testing.T) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 12
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker)))
@@ -2158,7 +2163,7 @@ func TestScCallsScWithEsdtCrossShard_SecondScRefusesPayment(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err := nodes[0].AccntState.GetExistingAccount(secondScAddress)
 	require.Nil(t, err)
@@ -2175,12 +2180,12 @@ func TestScCallsScWithEsdtCrossShard_SecondScRefusesPayment(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err = nodes[2].AccntState.GetExistingAccount(firstScAddress)
 	require.Nil(t, err)
-	nonce, round = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, idxProposers, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejected", 20)
-	_, _ = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, idxProposers, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejectedWithTransferAndExecute", 20)
+	nonce, round = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, leaders, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejected", 20)
+	_, _ = transferRejectedBySecondContract(t, nonce, round, nodes, tokenIssuer, leaders, initialSupply, tokenIdentifier, firstScAddress, secondScAddress, "transferToSecondContractRejectedWithTransferAndExecute", 20)
 }
 func transferRejectedBySecondContract(
@@ -2188,7 +2193,7 @@ func transferRejectedBySecondContract(
 	nonce, round uint64,
 	nodes []*integrationTests.TestProcessorNode,
 	tokenIssuer *integrationTests.TestProcessorNode,
-	idxProposers []int,
+	leaders []*integrationTests.TestProcessorNode,
 	initialSupply int64,
 	tokenIdentifier string,
 	firstScAddress []byte,
@@ -2210,7 +2215,7 @@ func transferRejectedBySecondContract(
 		integrationTests.AdditionalGasLimit)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundToPropagate, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundToPropagate, nonce, round)
 	time.Sleep(time.Second)
 	esdtCommon.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply-valueToSendToSc)
@@ -2250,11 +2255,11 @@ func multiTransferFromSC(t *testing.T, numOfShards int) {
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -2297,7 +2302,7 @@ func multiTransferFromSC(t *testing.T, numOfShards int) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 12
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := integrationTests.GetTokenIdentifier(nodes, []byte(ticker))
@@ -2319,7 +2324,7 @@ func multiTransferFromSC(t *testing.T, numOfShards int) {
 		integrationTests.AdditionalGasLimit,
 	)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 4, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 4, nonce, round)
 	_, err := ownerNode.AccntState.GetExistingAccount(scAddress)
 	require.Nil(t, err)
@@ -2327,7 +2332,7 @@ func multiTransferFromSC(t *testing.T, numOfShards int) {
 		[]byte(core.ESDTRoleLocalMint),
 	}
 	esdtCommon.SetRoles(nodes, scAddress, tokenIdentifier, roles)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round)
 	txData := txDataBuilder.NewBuilder()
 	txData.Func("batchTransferEsdtToken")
@@ -2349,7 +2354,7 @@ func multiTransferFromSC(t *testing.T, numOfShards int) {
 		integrationTests.AdditionalGasLimit,
 	)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, 12, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 12, nonce, round)
 	esdtCommon.CheckAddressHasTokens(t, destinationNode.OwnAccount.Address, nodes, tokenIdentifier, 0, 20)
 }
@@ -2366,6 +2371,7 @@ func TestESDTIssueUnderProtectedKeyWillReturnTokensBack(t *testing.T) {
 	OptimizeGasUsedInCrossMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch,
 		ScheduledMiniBlocksEnableEpoch:              integrationTests.UnreachableEpoch,
 		MiniBlockPartialExecutionEnableEpoch:        integrationTests.UnreachableEpoch,
+		AndromedaEnableEpoch:                        integrationTests.UnreachableEpoch,
 	}
 	nodes := integrationTests.CreateNodesWithEnableEpochs(
@@ -2375,11 +2381,11 @@ func TestESDTIssueUnderProtectedKeyWillReturnTokensBack(t *testing.T) {
 		enableEpochs,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -2406,14 +2412,14 @@ func TestESDTIssueUnderProtectedKeyWillReturnTokensBack(t *testing.T) {
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 1, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 1, nonce, round)
 	time.Sleep(time.Second)
 	userAcc := esdtCommon.GetUserAccountWithAddress(t, tokenIssuer.OwnAccount.Address, nodes)
 	balanceBefore := userAcc.GetBalance()
 	nrRoundsToPropagateMultiShard := 12
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	tokenIdentifier := integrationTests.GetTokenIdentifier(nodes, []byte(ticker))
 	require.Equal(t, 0, len(tokenIdentifier))
diff --git a/integrationTests/vm/esdt/roles/esdtRoles_test.go b/integrationTests/vm/esdt/roles/esdtRoles_test.go
index 5c117ed4edd..7601633ddf5 100644
--- a/integrationTests/vm/esdt/roles/esdtRoles_test.go
+++ b/integrationTests/vm/esdt/roles/esdtRoles_test.go
@@ -7,12 +7,13 @@ import (
 	"time"
 	"github.com/multiversx/mx-chain-core-go/core"
+	"github.com/stretchr/testify/require"
+
 	"github.com/multiversx/mx-chain-go/config"
 	"github.com/multiversx/mx-chain-go/integrationTests"
 	"github.com/multiversx/mx-chain-go/integrationTests/vm/esdt"
 	"github.com/multiversx/mx-chain-go/testscommon/txDataBuilder"
 	"github.com/multiversx/mx-chain-go/vm"
-	"github.com/stretchr/testify/require"
 )
 // Test scenario
@@ -35,11 +36,11 @@ func TestESDTRolesIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) {
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -65,7 +66,7 @@ func TestESDTRolesIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 6
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("FTT")))
@@ -75,7 +76,7 @@ func TestESDTRolesIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) {
 	setRole(nodes, nodes[0].OwnAccount.Address, []byte(tokenIdentifier), []byte(core.ESDTRoleLocalBurn))
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	esdt.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply.Int64())
@@ -93,7 +94,7 @@ func TestESDTRolesIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) {
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	// check balance ofter local mint
@@ -112,7 +113,7 @@ func TestESDTRolesIssueAndTransactionsOnMultiShardEnvironment(t *testing.T) {
 	)
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	// check balance ofter local mint
@@ -141,11 +142,11 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -171,7 +172,7 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 12
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte("FTT")))
@@ -180,14 +181,14 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme
 	setRole(nodes, nodes[0].OwnAccount.Address, []byte(tokenIdentifier), []byte(core.ESDTRoleLocalMint))
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	// unset special role
 	unsetRole(nodes, nodes[0].OwnAccount.Address, []byte(tokenIdentifier), []byte(core.ESDTRoleLocalMint))
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	esdt.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply.Int64())
@@ -207,7 +208,7 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard = 7
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	// check balance ofter local mint
@@ -215,7 +216,7 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme
 	setRole(nodes, nodes[0].OwnAccount.Address, []byte(tokenIdentifier), []byte(core.ESDTRoleLocalBurn))
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	// burn local tokens
@@ -231,7 +232,7 @@ func TestESDTRolesSetRolesAndUnsetRolesIssueAndTransactionsOnMultiShardEnvironme
 	)
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	// check balance ofter local mint
@@ -273,11 +274,11 @@ func TestESDTMintTransferAndExecute(t *testing.T) {
 		numMetachainNodes,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -295,7 +296,7 @@ func TestESDTMintTransferAndExecute(t *testing.T) {
 	round = integrationTests.IncrementAndPrintRound(round)
 	nonce++
-	scAddress := esdt.DeployNonPayableSmartContract(t, nodes, idxProposers, &nonce, &round, "../testdata/egld-esdt-swap.wasm")
+	scAddress := esdt.DeployNonPayableSmartContract(t, nodes, leaders, &nonce, &round, "../testdata/egld-esdt-swap.wasm")
 	// issue ESDT by calling exec on dest context on child contract
 	ticker := "DSN"
@@ -316,7 +317,7 @@ func TestESDTMintTransferAndExecute(t *testing.T) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 15
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := integrationTests.GetTokenIdentifier(nodes, []byte(ticker))
@@ -329,7 +330,7 @@ func TestESDTMintTransferAndExecute(t *testing.T) {
 		integrationTests.AdditionalGasLimit,
 	)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	valueToWrap := big.NewInt(1000)
@@ -346,7 +347,7 @@ func TestESDTMintTransferAndExecute(t *testing.T) {
 	}
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	for i, n := range nodes {
@@ -370,7 +371,7 @@ func TestESDTMintTransferAndExecute(t *testing.T) {
 	}
 	time.Sleep(time.Second)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	userAccount := esdt.GetUserAccountWithAddress(t, scAddress, nodes)
@@ -388,6 +389,7 @@ func TestESDTLocalBurnFromAnyoneOfThisToken(t *testing.T) {
 	enableEpochs := config.EnableEpochs{
 		ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch,
+		AndromedaEnableEpoch:           integrationTests.UnreachableEpoch,
 	}
 	nodes := integrationTests.CreateNodesWithEnableEpochs(
 		numOfShards,
@@ -396,11 +398,11 @@ func TestESDTLocalBurnFromAnyoneOfThisToken(t *testing.T) {
 		enableEpochs,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -431,7 +433,7 @@ func TestESDTLocalBurnFromAnyoneOfThisToken(t *testing.T) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 12
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker)))
@@ -445,7 +447,7 @@ func TestESDTLocalBurnFromAnyoneOfThisToken(t *testing.T) {
 		integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), node.OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit)
 	}
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	finalSupply := initialSupply
@@ -460,7 +462,7 @@ func TestESDTLocalBurnFromAnyoneOfThisToken(t *testing.T) {
 	txData.Clear().LocalBurnESDT(tokenIdentifier, finalSupply)
 	integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), tokenIssuer.OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit)
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	for _, node := range nodes {
@@ -479,6 +481,7 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) {
 	enableEpochs := config.EnableEpochs{
 		ScheduledMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch,
+		AndromedaEnableEpoch:           integrationTests.UnreachableEpoch,
 	}
 	nodes := integrationTests.CreateNodesWithEnableEpochs(
 		numOfShards,
@@ -487,11 +490,11 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) {
 		enableEpochs,
 	)
-	idxProposers := make([]int, numOfShards+1)
+	leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1)
 	for i := 0; i < numOfShards; i++ {
-		idxProposers[i] = i * nodesPerShard
+		leaders[i] = nodes[i*nodesPerShard]
 	}
-	idxProposers[numOfShards] = numOfShards * nodesPerShard
+	leaders[numOfShards] = nodes[numOfShards*nodesPerShard]
 	integrationTests.DisplayAndStartNodes(nodes)
@@ -522,7 +525,7 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) {
 	time.Sleep(time.Second)
 	nrRoundsToPropagateMultiShard := 12
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	tokenIdentifier := string(integrationTests.GetTokenIdentifier(nodes, []byte(ticker)))
@@ -530,7 +533,7 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) {
 	esdt.CheckAddressHasTokens(t, tokenIssuer.OwnAccount.Address, nodes, []byte(tokenIdentifier), 0, initialSupply)
 	time.Sleep(time.Second)
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 2, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 2, nonce, round)
 	time.Sleep(time.Second)
 	// send tx to other nodes
@@ -540,7 +543,7 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) {
 		integrationTests.CreateAndSendTransaction(tokenIssuer, nodes, big.NewInt(0), node.OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit)
 	}
-	nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	// send value back to the initial node
@@ -550,7 +553,7 @@ func TestESDTWithTransferRoleCrossShardShouldWork(t *testing.T) {
 		integrationTests.CreateAndSendTransaction(node, nodes, big.NewInt(0), tokenIssuer.OwnAccount.Address, txData.ToString(), integrationTests.AdditionalGasLimit)
 	}
-	_, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers)
+	_, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round)
 	time.Sleep(time.Second)
 	for _, node := range nodes[1:] {
diff --git a/integrationTests/vm/staking/baseTestMetaProcessor.go b/integrationTests/vm/staking/baseTestMetaProcessor.go
index 0ae2b5ed2d8..a1d5a36b82e 100644
--- a/integrationTests/vm/staking/baseTestMetaProcessor.go
+++ b/integrationTests/vm/staking/baseTestMetaProcessor.go
@@ -263,7 +263,7 @@ func (tmp *TestMetaProcessor) createNewHeader(t *testing.T, round uint64) *block
 		round,
 		currentHash,
 		currentHeader.GetRandSeed(),
-		tmp.NodesCoordinator.ConsensusGroupSize(core.MetachainShardId),
+		tmp.NodesCoordinator.ConsensusGroupSizeForShardAndEpoch(core.MetachainShardId, 0),
 	)
 	return header
diff --git a/integrationTests/vm/staking/componentsHolderCreator.go b/integrationTests/vm/staking/componentsHolderCreator.go
index e3673b08ec7..aa0a92c2115 100644
--- a/integrationTests/vm/staking/componentsHolderCreator.go
+++ b/integrationTests/vm/staking/componentsHolderCreator.go
@@ -11,9 +11,11 @@ import (
 	"github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice"
 	"github.com/multiversx/mx-chain-core-go/hashing/sha256"
 	"github.com/multiversx/mx-chain-core-go/marshal"
+
+	"github.com/multiversx/mx-chain-go/common"
"github.com/multiversx/mx-chain-go/common/enablers" "github.com/multiversx/mx-chain-go/common/forking" + "github.com/multiversx/mx-chain-go/common/graceperiod" "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -69,10 +71,11 @@ func createCoreComponents() factory.CoreComponentsHolder { StakingV4Step3EnableEpoch: stakingV4Step3EnableEpoch, GovernanceEnableEpoch: integrationTests.UnreachableEpoch, RefactorPeersMiniBlocksEnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, } enableEpochsHandler, _ := enablers.NewEnableEpochsHandler(configEnableEpochs, epochNotifier) - + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) return &integrationMocks.CoreComponentsStub{ InternalMarshalizerField: &marshal.GogoProtoMarshalizer{}, HasherField: sha256.NewSha256(), @@ -90,6 +93,7 @@ func createCoreComponents() factory.CoreComponentsHolder { EnableEpochsHandlerField: enableEpochsHandler, EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, RoundNotifierField: ¬ifierMocks.RoundNotifierStub{}, + EpochChangeGracePeriodHandlerField: gracePeriod, } } diff --git a/integrationTests/vm/staking/metaBlockProcessorCreator.go b/integrationTests/vm/staking/metaBlockProcessorCreator.go index 759458cf30e..cbb52101531 100644 --- a/integrationTests/vm/staking/metaBlockProcessorCreator.go +++ b/integrationTests/vm/staking/metaBlockProcessorCreator.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/epochStart/metachain" @@ -217,8 +218,9 @@ func createGenesisMetaBlock() *block.MetaBlock { func createHeaderValidator(coreComponents factory.CoreComponentsHolder) epochStart.HeaderValidator { argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: coreComponents.Hasher(), - Marshalizer: coreComponents.InternalMarshalizer(), + Hasher: coreComponents.Hasher(), + Marshalizer: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } headerValidator, _ := blproc.NewHeaderValidator(argsHeaderValidator) return headerValidator diff --git a/integrationTests/vm/staking/nodesCoordiantorCreator.go b/integrationTests/vm/staking/nodesCoordiantorCreator.go index 27a54719521..698df48c408 100644 --- a/integrationTests/vm/staking/nodesCoordiantorCreator.go +++ b/integrationTests/vm/staking/nodesCoordiantorCreator.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state/accounts" "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" nodesSetupMock "github.com/multiversx/mx-chain-go/testscommon/genesisMocks" "github.com/multiversx/mx-chain-go/testscommon/stakingcommon" "github.com/multiversx/mx-chain-storage-go/lrucache" @@ -39,10 +40,6 @@ func 
createNodesCoordinator( maxNodesConfig []config.MaxNodesChangeConfig, ) nodesCoordinator.NodesCoordinator { shufflerArgs := &nodesCoordinator.NodesShufflerArgs{ - NodesShard: numOfEligibleNodesPerShard, - NodesMeta: numOfMetaNodes, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: maxNodesConfig, EnableEpochs: config.EnableEpochs{ @@ -51,11 +48,23 @@ func createNodesCoordinator( }, EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } + nodeShuffler, _ := nodesCoordinator.NewHashValidatorsShuffler(shufflerArgs) cache, _ := lrucache.NewCache(10000) argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{ - ShardConsensusGroupSize: shardConsensusGroupSize, - MetaConsensusGroupSize: metaConsensusGroupSize, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + RoundDuration: 0, + Hysteresis: hysteresis, + ShardConsensusGroupSize: uint32(shardConsensusGroupSize), + ShardMinNumNodes: numOfEligibleNodesPerShard, + MetachainConsensusGroupSize: uint32(metaConsensusGroupSize), + MetachainMinNumNodes: numOfMetaNodes, + Adaptivity: adaptivity, + }, nil + }, + }, Marshalizer: coreComponents.InternalMarshalizer(), Hasher: coreComponents.Hasher(), ShardIDAsObserver: core.MetachainShardId, diff --git a/integrationTests/vm/systemVM/stakingSC_test.go b/integrationTests/vm/systemVM/stakingSC_test.go index 75e958f926b..453e6a62264 100644 --- a/integrationTests/vm/systemVM/stakingSC_test.go +++ b/integrationTests/vm/systemVM/stakingSC_test.go @@ -10,16 +10,16 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/integrationTests" - "github.com/multiversx/mx-chain-go/integrationTests/multiShard/endOfEpoch" integrationTestsVm "github.com/multiversx/mx-chain-go/integrationTests/vm" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/accounts" "github.com/multiversx/mx-chain-go/vm" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { @@ -38,6 +38,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { StakingV4Step1EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step2EnableEpoch: integrationTests.UnreachableEpoch, StakingV4Step3EnableEpoch: integrationTests.UnreachableEpoch, + AndromedaEnableEpoch: integrationTests.UnreachableEpoch, } nodes := integrationTests.CreateNodesWithEnableEpochs( @@ -47,11 +48,11 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { enableEpochsConfig, ) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for i := 0; i < numOfShards; i++ { - idxProposers[i] = i * nodesPerShard + leaders[i] = nodes[i*nodesPerShard] } - idxProposers[numOfShards] = numOfShards * nodesPerShard + leaders[numOfShards] = 
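// Editor's sketch: ArgNodesCoordinator no longer carries fixed consensus group sizes; the
// values are served per epoch through a ChainParametersHandler. In tests, the stub shape
// from this diff (trimmed here) is enough; production code wires the real chain
// parameters holder instead.
chainParamsStub := &chainParameters.ChainParametersHandlerStub{
	ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) {
		return config.ChainParametersByEpochConfig{
			ShardConsensusGroupSize:     uint32(shardConsensusGroupSize),
			ShardMinNumNodes:            numOfEligibleNodesPerShard,
			MetachainConsensusGroupSize: uint32(metaConsensusGroupSize),
			MetachainMinNumNodes:        numOfMetaNodes,
			Hysteresis:                  hysteresis,
			Adaptivity:                  adaptivity,
		}, nil
	},
}
argumentsNodesCoordinator := nodesCoordinator.ArgNodesCoordinator{
	ChainParametersHandler: chainParamsStub,
	// ... other arguments as before
}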
nodes[numOfShards*nodesPerShard] integrationTests.DisplayAndStartNodes(nodes) @@ -87,7 +88,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { nrRoundsToPropagateMultiShard := 10 integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) @@ -109,11 +110,11 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { time.Sleep(time.Second) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) // ----- wait for unbond period integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) manualSetToInactiveStateStakedPeers(t, nodes) @@ -127,7 +128,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironment(t *testing.T) { time.Sleep(time.Second) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) verifyUnbound(t, nodes) } @@ -152,18 +153,16 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironmentWithValidatorStatist ) nodes := make([]*integrationTests.TestProcessorNode, 0) - idxProposers := make([]int, numOfShards+1) + leaders := make([]*integrationTests.TestProcessorNode, numOfShards+1) for _, nds := range nodesMap { nodes = append(nodes, nds...) 
} - for _, nds := range nodesMap { - idx, err := integrationTestsVm.GetNodeIndex(nodes, nds[0]) - require.Nil(t, err) - - idxProposers = append(idxProposers, idx) + for i := 0; i < numOfShards; i++ { + leaders[i] = nodesMap[uint32(i)][0] } + leaders[numOfShards] = nodesMap[core.MetachainShardId][0] integrationTests.DisplayAndStartNodes(nodes) @@ -203,7 +202,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironmentWithValidatorStatist nrRoundsToPropagateMultiShard := 10 integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) time.Sleep(time.Second) @@ -227,7 +226,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironmentWithValidatorStatist time.Sleep(time.Second) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) roundsPerEpoch := uint64(10) for _, node := range nodes { @@ -237,7 +236,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironmentWithValidatorStatist // ----- wait for unbound period integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - nonce, round = integrationTests.WaitOperationToBeDone(t, nodes, 10, nonce, round, idxProposers) + nonce, round = integrationTests.WaitOperationToBeDone(t, leaders, nodes, 10, nonce, round) // ----- send unBound for index, node := range nodes { @@ -252,7 +251,7 @@ func TestStakingUnstakingAndUnbondingOnMultiShardEnvironmentWithValidatorStatist time.Sleep(time.Second) integrationTests.AddSelfNotarizedHeaderByMetachain(nodes) - _, _ = integrationTests.WaitOperationToBeDone(t, nodes, nrRoundsToPropagateMultiShard, nonce, round, idxProposers) + _, _ = integrationTests.WaitOperationToBeDone(t, leaders, nodes, nrRoundsToPropagateMultiShard, nonce, round) verifyUnbound(t, nodes) } @@ -322,7 +321,6 @@ func TestStakeWithRewardsAddressAndValidatorStatistics(t *testing.T) { } nbBlocksToProduce := roundsPerEpoch * 3 - var consensusNodes map[uint32][]*integrationTests.TestProcessorNode for i := uint64(0); i < nbBlocksToProduce; i++ { for _, nodesSlice := range nodesMap { @@ -330,9 +328,8 @@ func TestStakeWithRewardsAddressAndValidatorStatistics(t *testing.T) { integrationTests.AddSelfNotarizedHeaderByMetachain(nodesSlice) } - _, _, consensusNodes = integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) - indexesProposers := endOfEpoch.GetBlockProposersIndexes(consensusNodes, nodesMap) - integrationTests.SyncAllShardsWithRoundBlock(t, nodesMap, indexesProposers, round) + proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap) + integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round) round++ nonce++ diff --git a/integrationTests/vm/wasm/queries/queries_test.go b/integrationTests/vm/wasm/queries/queries_test.go index e83170e6e0b..c88be80b43b 100644 --- a/integrationTests/vm/wasm/queries/queries_test.go +++ b/integrationTests/vm/wasm/queries/queries_test.go @@ -6,6 +6,7 @@ import ( "fmt" "math/big" "testing" + "time" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/integrationTests" @@ -132,6 +133,9 @@ func deploy(t *testing.T, 
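// Editor's sketch (based on the calls visible above): the block production loop no longer
// extracts proposer indexes via GetBlockProposersIndexes; AllShardsProposeBlock returns
// propose data that is handed straight to the sync helper. Other per-round bookkeeping
// (UpdateRound, AddSelfNotarizedHeaderByMetachain) is elided here.
for i := uint64(0); i < nbBlocksToProduce; i++ {
	proposeData := integrationTests.AllShardsProposeBlock(round, nonce, nodesMap)
	integrationTests.SyncAllShardsWithRoundBlock(t, proposeData, nodesMap, round)
	round++
	nonce++
}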
network *integrationTests.MiniNetwork, sender []byte, ) require.NoError(t, err) + // Allow the transaction to reach the mempool. + time.Sleep(100 * time.Millisecond) + scAddress, _ := network.ShardNode.BlockchainHook.NewAddress(sender, 0, factory.WasmVirtualMachine) return scAddress } @@ -148,6 +152,9 @@ func setState(t *testing.T, network *integrationTests.MiniNetwork, scAddress []b ) require.NoError(t, err) + + // Allow the transaction to reach the mempool. + time.Sleep(100 * time.Millisecond) } func getState(t *testing.T, node *integrationTests.TestProcessorNode, scAddress []byte, blockNonce core.OptionalUint64) int { diff --git a/integrationTests/vm/wasm/upgrades/upgrades_test.go b/integrationTests/vm/wasm/upgrades/upgrades_test.go index c6313d65e73..09b3c0ee49c 100644 --- a/integrationTests/vm/wasm/upgrades/upgrades_test.go +++ b/integrationTests/vm/wasm/upgrades/upgrades_test.go @@ -6,12 +6,13 @@ import ( "math/big" "testing" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/integrationTests" "github.com/multiversx/mx-chain-go/integrationTests/vm/wasm" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/require" ) func TestUpgrades_Hello(t *testing.T) { @@ -212,7 +213,7 @@ func TestUpgrades_HelloTrialAndError(t *testing.T) { require.Nil(t, err) scAddress, _ := network.ShardNode.BlockchainHook.NewAddress(alice.Address, 0, factory.WasmVirtualMachine) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{24}, query(t, network.ShardNode, scAddress, "getUltimateAnswer")) // Upgrade as Bob - upgrade should fail, since Alice is the owner @@ -225,7 +226,7 @@ func TestUpgrades_HelloTrialAndError(t *testing.T) { ) require.Nil(t, err) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{24}, query(t, network.ShardNode, scAddress, "getUltimateAnswer")) // Now upgrade as Alice, should work @@ -238,7 +239,7 @@ func TestUpgrades_HelloTrialAndError(t *testing.T) { ) require.Nil(t, err) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{42}, query(t, network.ShardNode, scAddress, "getUltimateAnswer")) } @@ -269,7 +270,7 @@ func TestUpgrades_CounterTrialAndError(t *testing.T) { require.Nil(t, err) scAddress, _ := network.ShardNode.BlockchainHook.NewAddress(alice.Address, 0, factory.WasmVirtualMachine) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{1}, query(t, network.ShardNode, scAddress, "get")) // Increment the counter (could be either Bob or Alice) @@ -282,7 +283,7 @@ func TestUpgrades_CounterTrialAndError(t *testing.T) { ) require.Nil(t, err) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{2}, query(t, network.ShardNode, scAddress, "get")) // Upgrade as Bob - upgrade should fail, since Alice is the owner (counter.init() not executed, state not reset) @@ -295,7 +296,7 @@ func TestUpgrades_CounterTrialAndError(t *testing.T) { ) require.Nil(t, err) - network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{2}, query(t, network.ShardNode, scAddress, "get")) // Now upgrade as Alice, should work (state is reset by counter.init()) @@ -308,7 +309,7 @@ func TestUpgrades_CounterTrialAndError(t *testing.T) { ) require.Nil(t, err) - 
network.Continue(t, 1) + network.Continue(t, 2) require.Equal(t, []byte{1}, query(t, network.ShardNode, scAddress, "get")) } diff --git a/integrationTests/vm/wasm/wasmvm/testRunner.go b/integrationTests/vm/wasm/wasmvm/testRunner.go index e6756b1a4c2..f5384189d16 100644 --- a/integrationTests/vm/wasm/wasmvm/testRunner.go +++ b/integrationTests/vm/wasm/wasmvm/testRunner.go @@ -3,6 +3,7 @@ package wasmvm import ( "crypto/rand" "encoding/hex" + "errors" "fmt" "math/big" "time" @@ -171,7 +172,7 @@ func DeployAndExecuteERC20WithBigInt( return nil, err } if returnCode != vmcommon.Ok { - return nil, fmt.Errorf(returnCode.String()) + return nil, errors.New(returnCode.String()) } ownerNonce++ @@ -263,7 +264,7 @@ func SetupERC20Test( return err } if returnCode != vmcommon.Ok { - return fmt.Errorf(returnCode.String()) + return errors.New(returnCode.String()) } testContext.ContractOwner.Nonce++ diff --git a/keysManagement/managedPeersHolder.go b/keysManagement/managedPeersHolder.go index 8156b64c8eb..39f80f6bbaf 100644 --- a/keysManagement/managedPeersHolder.go +++ b/keysManagement/managedPeersHolder.go @@ -12,10 +12,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/redundancy/common" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("keysManagement") diff --git a/node/chainSimulator/chainSimulator.go b/node/chainSimulator/chainSimulator.go index 72d21b33fe2..288e7d25a23 100644 --- a/node/chainSimulator/chainSimulator.go +++ b/node/chainSimulator/chainSimulator.go @@ -6,20 +6,11 @@ import ( "encoding/hex" "errors" "fmt" + "math/big" "sync" "time" - "github.com/multiversx/mx-chain-go/config" - "github.com/multiversx/mx-chain-go/factory" - "github.com/multiversx/mx-chain-go/node/chainSimulator/components" - "github.com/multiversx/mx-chain-go/node/chainSimulator/components/heartbeat" - "github.com/multiversx/mx-chain-go/node/chainSimulator/configs" - "github.com/multiversx/mx-chain-go/node/chainSimulator/dtos" - chainSimulatorErrors "github.com/multiversx/mx-chain-go/node/chainSimulator/errors" - "github.com/multiversx/mx-chain-go/node/chainSimulator/process" - mxChainSharding "github.com/multiversx/mx-chain-go/sharding" - "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/sharding" @@ -31,6 +22,14 @@ import ( crypto "github.com/multiversx/mx-chain-crypto-go" "github.com/multiversx/mx-chain-crypto-go/signing" "github.com/multiversx/mx-chain-crypto-go/signing/mcl" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/factory" + "github.com/multiversx/mx-chain-go/node/chainSimulator/components" + "github.com/multiversx/mx-chain-go/node/chainSimulator/components/heartbeat" + 
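// Editor's note on the fmt.Errorf -> errors.New change above: passing a dynamic string as
// the format argument of fmt.Errorf is flagged by go vet's printf check and would
// misinterpret any '%' inside the VM return code text, while errors.New keeps the message
// verbatim. A formatted variant (wording illustrative) remains available when extra
// context is wanted:
if returnCode != vmcommon.Ok {
	return nil, errors.New(returnCode.String()) // message used verbatim
	// or: return nil, fmt.Errorf("unexpected VM return code: %s", returnCode.String())
}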
"github.com/multiversx/mx-chain-go/node/chainSimulator/configs" + "github.com/multiversx/mx-chain-go/node/chainSimulator/dtos" + chainSimulatorErrors "github.com/multiversx/mx-chain-go/node/chainSimulator/errors" + "github.com/multiversx/mx-chain-go/node/chainSimulator/process" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -52,6 +51,7 @@ type ArgsChainSimulator struct { NumOfShards uint32 MinNodesPerShard uint32 MetaChainMinNodes uint32 + Hysteresis float32 NumNodesWaitingListShard uint32 NumNodesWaitingListMeta uint32 GenesisTimestamp int64 @@ -88,8 +88,8 @@ type simulator struct { func NewChainSimulator(args ArgsChainSimulator) (*simulator, error) { return NewBaseChainSimulator(ArgsBaseChainSimulator{ ArgsChainSimulator: args, - ConsensusGroupSize: configs.ChainSimulatorConsensusGroupSize, - MetaChainConsensusGroupSize: configs.ChainSimulatorConsensusGroupSize, + ConsensusGroupSize: args.MinNodesPerShard, + MetaChainConsensusGroupSize: args.MetaChainMinNodes, }) } @@ -124,6 +124,7 @@ func (s *simulator) createChainHandlers(args ArgsBaseChainSimulator) error { ConsensusGroupSize: args.ConsensusGroupSize, MetaChainMinNodes: args.MetaChainMinNodes, MetaChainConsensusGroupSize: args.MetaChainConsensusGroupSize, + Hysteresis: args.Hysteresis, RoundsPerEpoch: args.RoundsPerEpoch, InitialEpoch: args.InitialEpoch, AlterConfigsFunction: args.AlterConfigsFunction, @@ -208,6 +209,9 @@ func (s *simulator) createChainHandlers(args ArgsBaseChainSimulator) error { s.initialWalletKeys = outputConfigs.InitialWallets s.validatorsPrivateKeys = outputConfigs.ValidatorsPrivateKeys + s.addProofs() + s.setBasePeerIds() + log.Info("running the chain simulator with the following parameters", "number of shards (including meta)", args.NumOfShards+1, "round per epoch", outputConfigs.Configs.GeneralConfig.EpochStartConfig.RoundsPerEpoch, @@ -219,6 +223,39 @@ func (s *simulator) createChainHandlers(args ArgsBaseChainSimulator) error { return nil } +func (s *simulator) setBasePeerIds() { + peerIds := make(map[uint32]core.PeerID, 0) + for _, nodeHandler := range s.nodes { + peerID := nodeHandler.GetNetworkComponents().NetworkMessenger().ID() + peerIds[nodeHandler.GetShardCoordinator().SelfId()] = peerID + } + + for _, nodeHandler := range s.nodes { + nodeHandler.SetBasePeers(peerIds) + } +} + +func (s *simulator) addProofs() { + proofs := make([]*block.HeaderProof, 0, len(s.nodes)) + + for shardID, nodeHandler := range s.nodes { + hash := nodeHandler.GetChainHandler().GetGenesisHeaderHash() + proofs = append(proofs, &block.HeaderProof{ + HeaderShardId: shardID, + HeaderHash: hash, + }) + } + + metachainProofsPool := s.GetNodeHandler(core.MetachainShardId).GetDataComponents().Datapool().Proofs() + for _, proof := range proofs { + _ = metachainProofsPool.AddProof(proof) + + if proof.HeaderShardId != core.MetachainShardId { + _ = s.GetNodeHandler(proof.HeaderShardId).GetDataComponents().Datapool().Proofs().AddProof(proof) + } + } +} + func computeStartTimeBaseOnInitialRound(args ArgsChainSimulator) int64 { return args.GenesisTimestamp + int64(args.RoundDurationInMillis/1000)*args.InitialRound } @@ -341,7 +378,14 @@ func (s *simulator) ForceChangeOfEpoch() error { epoch := s.nodes[core.MetachainShardId].GetProcessComponents().EpochStartTrigger().Epoch() s.mutex.Unlock() - return s.GenerateBlocksUntilEpochIsReached(int32(epoch + 1)) + err := s.GenerateBlocksUntilEpochIsReached(int32(epoch + 1)) + if err != nil { + return err + } + 
+ s.incrementRoundOnAllValidators() + + return s.allNodesCreateBlocks() } func (s *simulator) allNodesCreateBlocks() error { @@ -402,43 +446,46 @@ func (s *simulator) AddValidatorKeys(validatorsPrivateKeys [][]byte) error { // GenerateAndMintWalletAddress will generate an address in the provided shard and will mint that address with the provided value // if the target shard ID value does not correspond to a node handled by the chain simulator, the address will be generated in a random shard ID func (s *simulator) GenerateAndMintWalletAddress(targetShardID uint32, value *big.Int) (dtos.WalletAddress, error) { - addressConverter := s.nodes[core.MetachainShardId].GetCoreComponents().AddressPubKeyConverter() - nodeHandler := s.GetNodeHandler(targetShardID) - var buff []byte - if check.IfNil(nodeHandler) { - buff = generateAddress(addressConverter.Len()) - } else { - buff = generateAddressInShard(nodeHandler.GetShardCoordinator(), addressConverter.Len()) - } - - address, err := addressConverter.Encode(buff) - if err != nil { - return dtos.WalletAddress{}, err - } + wallet := s.GenerateAddressInShard(targetShardID) - err = s.SetStateMultiple([]*dtos.AddressState{ + err := s.SetStateMultiple([]*dtos.AddressState{ { - Address: address, + Address: wallet.Bech32, Balance: value.String(), }, }) - return dtos.WalletAddress{ - Bech32: address, - Bytes: buff, - }, err + return wallet, err } -func generateAddressInShard(shardCoordinator mxChainSharding.Coordinator, len int) []byte { +// GenerateAddressInShard will generate a wallet address based on the provided shard +func (s *simulator) GenerateAddressInShard(providedShardID uint32) dtos.WalletAddress { + converter := s.nodes[core.MetachainShardId].GetCoreComponents().AddressPubKeyConverter() + nodeHandler := s.GetNodeHandler(providedShardID) + if check.IfNil(nodeHandler) { + return generateWalletAddress(converter) + } + for { - buff := generateAddress(len) - shardID := shardCoordinator.ComputeId(buff) - if shardID == shardCoordinator.SelfId() { - return buff + buff := generateAddress(converter.Len()) + if nodeHandler.GetShardCoordinator().ComputeId(buff) == providedShardID { + return generateWalletAddressFromBuffer(converter, buff) } } } +func generateWalletAddress(converter core.PubkeyConverter) dtos.WalletAddress { + buff := generateAddress(converter.Len()) + return generateWalletAddressFromBuffer(converter, buff) +} + +func generateWalletAddressFromBuffer(converter core.PubkeyConverter, buff []byte) dtos.WalletAddress { + return dtos.WalletAddress{ + Bech32: converter.SilentEncode(buff, log), + Bytes: buff, + } +} + func generateAddress(len int) []byte { buff := make([]byte, len) _, _ = rand.Read(buff) @@ -622,7 +669,7 @@ func (s *simulator) computeTransactionsStatus(txsWithResult []*transactionWithRe result, errGet := s.GetNodeHandler(destinationShardID).GetFacadeHandler().GetTransaction(resultTx.hexHash, true) if errGet == nil && result.Status != transaction.TxStatusPending { - log.Info("############## transaction was executed ##############", "txHash", resultTx.hexHash) + log.Trace("############## transaction was executed ##############", "txHash", resultTx.hexHash) resultTx.result = result continue } @@ -664,7 +711,7 @@ func (s *simulator) sendTx(tx *transaction.Transaction) (string, error) { for { recoveredTx, _ := node.GetFacadeHandler().GetTransaction(txHashHex, false) if recoveredTx != nil { - log.Info("############## send transaction ##############", "txHash", txHashHex) + log.Trace("############## send transaction ##############", 
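// Editor's usage sketch for the methods introduced above (values illustrative): callers
// can now obtain a shard-local address without funding it, or generate and mint in one
// call; both return a dtos.WalletAddress with Bech32 and Bytes populated.
wallet := s.GenerateAddressInShard(1) // address guaranteed to map to shard 1
log.Debug("generated wallet", "bech32", wallet.Bech32)

funded, err := s.GenerateAndMintWalletAddress(1, big.NewInt(1000000000000000000)) // ~1 EGLD, assuming 18 decimals
if err != nil {
	return err
}
_ = funded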
"txHash", txHashHex) return txHashHex, nil } diff --git a/node/chainSimulator/chainSimulator_test.go b/node/chainSimulator/chainSimulator_test.go index 2ba89205afe..d5e8d7baff8 100644 --- a/node/chainSimulator/chainSimulator_test.go +++ b/node/chainSimulator/chainSimulator_test.go @@ -1,7 +1,6 @@ package chainSimulator import ( - "github.com/multiversx/mx-chain-go/errors" "math/big" "strings" "testing" @@ -10,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/errors" chainSimulatorCommon "github.com/multiversx/mx-chain-go/integrationTests/chainSimulator" "github.com/multiversx/mx-chain-go/node/chainSimulator/components/api" "github.com/multiversx/mx-chain-go/node/chainSimulator/configs" @@ -36,14 +36,25 @@ func TestNewChainSimulator(t *testing.T) { NumOfShards: 3, GenesisTimestamp: startTime, RoundDurationInMillis: roundDurationInMillis, - RoundsPerEpoch: core.OptionalUint64{}, - ApiInterface: api.NewNoApiInterface(), - MinNodesPerShard: 1, - MetaChainMinNodes: 1, + RoundsPerEpoch: core.OptionalUint64{ + HasValue: true, + Value: 20, + }, + ApiInterface: api.NewNoApiInterface(), + MinNodesPerShard: 3, + MetaChainMinNodes: 3, }) require.Nil(t, err) require.NotNil(t, chainSimulator) + for i := 0; i < 8; i++ { + err = chainSimulator.ForceChangeOfEpoch() + require.Nil(t, err) + } + + err = chainSimulator.GenerateBlocks(50) + require.Nil(t, err) + time.Sleep(time.Second) chainSimulator.Close() diff --git a/node/chainSimulator/components/coreComponents.go b/node/chainSimulator/components/coreComponents.go index af28ce185ff..f66fc43c9ab 100644 --- a/node/chainSimulator/components/coreComponents.go +++ b/node/chainSimulator/components/coreComponents.go @@ -6,12 +6,14 @@ import ( "time" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/chainparametersnotifier" "github.com/multiversx/mx-chain-go/common/enablers" factoryPubKey "github.com/multiversx/mx-chain-go/common/factory" + "github.com/multiversx/mx-chain-go/common/fieldsChecker" "github.com/multiversx/mx-chain-go/common/forking" + "github.com/multiversx/mx-chain-go/common/graceperiod" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/consensus" - "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/epochStart/notifier" "github.com/multiversx/mx-chain-go/factory" "github.com/multiversx/mx-chain-go/ntp" @@ -74,6 +76,10 @@ type coreComponentsHolder struct { processStatusHandler common.ProcessStatusHandler hardforkTriggerPubKey []byte enableEpochsHandler common.EnableEpochsHandler + chainParametersSubscriber process.ChainParametersSubscriber + chainParametersHandler process.ChainParametersHandler + fieldsSizeChecker common.FieldsSizeChecker + epochChangeGracePeriodHandler common.EpochChangeGracePeriodHandler } // ArgsCoreComponentsHolder will hold arguments needed for the core components holder @@ -145,10 +151,33 @@ func CreateCoreComponents(args ArgsCoreComponentsHolder) (*coreComponentsHolder, } instance.watchdog = &watchdog.DisabledWatchdog{} - 
instance.alarmScheduler = &mock.AlarmSchedulerStub{} + instance.alarmScheduler = &testscommon.AlarmSchedulerStub{} instance.syncTimer = &testscommon.SyncTimerStub{} - instance.genesisNodesSetup, err = sharding.NewNodesSetup(args.NodesSetupPath, instance.addressPubKeyConverter, instance.validatorPubKeyConverter, args.NumShards) + instance.epochStartNotifierWithConfirm = notifier.NewEpochStartSubscriptionHandler() + instance.chainParametersSubscriber = chainparametersnotifier.NewChainParametersNotifier() + chainParametersNotifier := chainparametersnotifier.NewChainParametersNotifier() + argsChainParametersHandler := sharding.ArgsChainParametersHolder{ + EpochStartEventNotifier: instance.epochStartNotifierWithConfirm, + ChainParameters: args.Config.GeneralSettings.ChainParametersByEpoch, + ChainParametersNotifier: chainParametersNotifier, + } + instance.chainParametersHandler, err = sharding.NewChainParametersHolder(argsChainParametersHandler) + if err != nil { + return nil, err + } + + instance.epochChangeGracePeriodHandler, err = graceperiod.NewEpochChangeGracePeriod(args.Config.GeneralSettings.EpochChangeGracePeriodByEpoch) + if err != nil { + return nil, err + } + + var nodesSetup config.NodesConfig + err = core.LoadJsonFile(&nodesSetup, args.NodesSetupPath) + if err != nil { + return nil, err + } + instance.genesisNodesSetup, err = sharding.NewNodesSetup(nodesSetup, instance.chainParametersHandler, instance.addressPubKeyConverter, instance.validatorPubKeyConverter, args.NumShards) if err != nil { return nil, err } @@ -164,10 +193,6 @@ func CreateCoreComponents(args ArgsCoreComponentsHolder) (*coreComponentsHolder, return nil, err } - if err != nil { - return nil, err - } - argsEconomicsHandler := economics.ArgsNewEconomicsData{ TxVersionChecker: instance.txVersionChecker, Economics: &args.EconomicsConfig, @@ -184,12 +209,10 @@ func CreateCoreComponents(args ArgsCoreComponentsHolder) (*coreComponentsHolder, instance.apiEconomicsData = instance.economicsData instance.ratingsData, err = rating.NewRatingsData(rating.RatingsDataArg{ - Config: args.RatingConfig, - ShardConsensusSize: args.ConsensusGroupSize, - MetaConsensusSize: args.MetaChainConsensusGroupSize, - ShardMinNodes: args.MinNodesPerShard, - MetaMinNodes: args.MinNodesMeta, - RoundDurationMiliseconds: args.RoundDurationInMs, + EpochNotifier: instance.epochNotifier, + Config: args.RatingConfig, + ChainParametersHolder: instance.chainParametersHandler, + RoundDurationMilliseconds: args.RoundDurationInMs, }) if err != nil { return nil, err @@ -201,10 +224,6 @@ func CreateCoreComponents(args ArgsCoreComponentsHolder) (*coreComponentsHolder, } instance.nodesShuffler, err = nodesCoordinator.NewHashValidatorsShuffler(&nodesCoordinator.NodesShufflerArgs{ - NodesShard: args.MinNodesPerShard, - NodesMeta: args.MinNodesMeta, - Hysteresis: 0, - Adaptivity: false, ShuffleBetweenShards: true, MaxNodesEnableConfig: args.EnableEpochsConfig.MaxNodesChangeEnableEpoch, EnableEpochsHandler: instance.enableEpochsHandler, @@ -220,7 +239,6 @@ func CreateCoreComponents(args ArgsCoreComponentsHolder) (*coreComponentsHolder, return nil, err } - instance.epochStartNotifierWithConfirm = notifier.NewEpochStartSubscriptionHandler() instance.chanStopNodeProcess = args.ChanStopNodeProcess instance.genesisTime = time.Unix(instance.genesisNodesSetup.GetStartTime(), 0) instance.chainID = args.Config.GeneralSettings.ChainID @@ -239,6 +257,12 @@ func CreateCoreComponents(args ArgsCoreComponentsHolder) (*coreComponentsHolder, } instance.hardforkTriggerPubKey = 
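// Editor's sketch of the new genesis nodes setup wiring shown above: the nodes file is
// first unmarshalled into config.NodesConfig, and NewNodesSetup now also receives the
// chain parameters handler, since consensus sizes no longer live in nodesSetup.json.
var nodesCfg config.NodesConfig
if err := core.LoadJsonFile(&nodesCfg, args.NodesSetupPath); err != nil {
	return nil, err
}
genesisNodesSetup, err := sharding.NewNodesSetup(
	nodesCfg,
	instance.chainParametersHandler, // built via sharding.NewChainParametersHolder above
	instance.addressPubKeyConverter,
	instance.validatorPubKeyConverter,
	args.NumShards,
)
if err != nil {
	return nil, err
}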
pubKeyBytes + fchecker, err := fieldsChecker.NewFieldsSizeChecker(instance.chainParametersHandler, hasher) + if err != nil { + return nil, err + } + instance.fieldsSizeChecker = fchecker + instance.collectClosableComponents() return instance, nil @@ -430,6 +454,26 @@ func (c *coreComponentsHolder) EnableEpochsHandler() common.EnableEpochsHandler return c.enableEpochsHandler } +// ChainParametersSubscriber will return the chain parameters subscriber +func (c *coreComponentsHolder) ChainParametersSubscriber() process.ChainParametersSubscriber { + return c.chainParametersSubscriber +} + +// ChainParametersHandler will return the chain parameters handler +func (c *coreComponentsHolder) ChainParametersHandler() process.ChainParametersHandler { + return c.chainParametersHandler +} + +// FieldsSizeChecker will return the fields size checker component +func (c *coreComponentsHolder) FieldsSizeChecker() common.FieldsSizeChecker { + return c.fieldsSizeChecker +} + +// EpochChangeGracePeriodHandler will return the epoch change grace period handler +func (c *coreComponentsHolder) EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler { + return c.epochChangeGracePeriodHandler +} + func (c *coreComponentsHolder) collectClosableComponents() { c.closeHandler.AddComponent(c.alarmScheduler) c.closeHandler.AddComponent(c.syncTimer) diff --git a/node/chainSimulator/components/coreComponents_test.go b/node/chainSimulator/components/coreComponents_test.go index d03310f6165..3517fa1605a 100644 --- a/node/chainSimulator/components/coreComponents_test.go +++ b/node/chainSimulator/components/coreComponents_test.go @@ -7,6 +7,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/endProcess" "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/testscommon/components" + "github.com/multiversx/mx-chain-go/config" ) @@ -30,18 +32,32 @@ func createArgsCoreComponentsHolder() ArgsCoreComponentsHolder { }, AddressPubkeyConverter: config.PubkeyConfig{ Length: 32, - Type: "hex", + Type: "bech32", + Hrp: "erd", }, ValidatorPubkeyConverter: config.PubkeyConfig{ - Length: 128, + Length: 96, Type: "hex", }, GeneralSettings: config.GeneralSettingsConfig{ ChainID: "T", MinTransactionVersion: 1, + ChainParametersByEpoch: []config.ChainParametersByEpochConfig{ + { + EnableEpoch: 0, + RoundDuration: 4000, + ShardConsensusGroupSize: 1, + ShardMinNumNodes: 1, + MetachainConsensusGroupSize: 1, + MetachainMinNumNodes: 1, + Hysteresis: 0, + Adaptivity: false, + }, + }, + EpochChangeGracePeriodByEpoch: []config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}, }, Hardfork: config.HardforkConfig{ - PublicKeyToListenFrom: "41378f754e2c7b2745208c3ed21b151d297acdc84c3aca00b9e292cf28ec2d444771070157ea7760ed83c26f4fed387d0077e00b563a95825dac2cbc349fc0025ccf774e37b0a98ad9724d30e90f8c29b4091ccb738ed9ffc0573df776ee9ea30b3c038b55e532760ea4a8f152f2a52848020e5cee1cc537f2c2323399723081", + PublicKeyToListenFrom: components.DummyPk, }, }, EnableEpochsConfig: config.EnableEpochs{}, @@ -87,7 +103,7 @@ func createArgsCoreComponentsHolder() ArgsCoreComponentsHolder { LeaderPercentage: 0.1, DeveloperPercentage: 0.1, ProtocolSustainabilityPercentage: 0.1, - ProtocolSustainabilityAddress: "2c5594ae2f77a913119bc9db52833245a5879674cd4aeaedcd92f6f9e7edf17d", // tests use hex address pub key conv + ProtocolSustainabilityAddress: testingProtocolSustainabilityAddress, TopUpGradientPoint: "300000000000000000000", TopUpFactor: 0.25, 
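// Editor's sketch (mirrors only shapes visible in this diff): rating steps become a
// per-epoch slice (RatingStepsByEpoch) and NewRatingsData is driven by an epoch notifier
// plus the chain parameters holder instead of hard-coded consensus sizes and node counts.
shardSteps := []config.RatingSteps{
	{
		HoursToMaxRatingFromStartRating: 2,
		ProposerValidatorImportance:     1,
		ProposerDecreaseFactor:          -4,
		ValidatorDecreaseFactor:         -4,
		ConsecutiveMissedBlocksPenalty:  1.2,
	},
	// further entries can be appended for later epochs
}
_ = shardSteps // assigned to RatingsConfig.ShardChain.RatingStepsByEpoch in the test config below

ratingsData, err := rating.NewRatingsData(rating.RatingsDataArg{
	EpochNotifier:             instance.epochNotifier,
	Config:                    args.RatingConfig, // now carries RatingStepsByEpoch for shard and meta
	ChainParametersHolder:     instance.chainParametersHandler,
	RoundDurationMilliseconds: args.RoundDurationInMs,
})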
EpochEnable: 0, @@ -108,27 +124,31 @@ func createArgsCoreComponentsHolder() ArgsCoreComponentsHolder { }, }, ShardChain: config.ShardChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 2, - ProposerValidatorImportance: 1, - ProposerDecreaseFactor: -4, - ValidatorDecreaseFactor: -4, - ConsecutiveMissedBlocksPenalty: 1.2, + RatingStepsByEpoch: []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.2, + }, }, }, MetaChain: config.MetaChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 2, - ProposerValidatorImportance: 1, - ProposerDecreaseFactor: -4, - ValidatorDecreaseFactor: -4, - ConsecutiveMissedBlocksPenalty: 1.3, + RatingStepsByEpoch: []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.3, + }, }, }, }, ChanStopNodeProcess: make(chan endProcess.ArgEndProcess), InitialRound: 0, - NodesSetupPath: "../../../sharding/mock/testdata/nodesSetupMock.json", + NodesSetupPath: "../../../cmd/node/config/nodesSetup.json", GasScheduleFilename: "../../../cmd/node/config/gasSchedules/gasScheduleV8.toml", NumShards: 3, WorkingDir: ".", @@ -293,8 +313,8 @@ func TestCoreComponents_GettersSetters(t *testing.T) { require.Equal(t, "T", comp.ChainID()) require.Equal(t, uint32(1), comp.MinTransactionVersion()) require.NotNil(t, comp.TxVersionChecker()) - require.Equal(t, uint32(64), comp.EncodedAddressLen()) - hfPk, _ := hex.DecodeString("41378f754e2c7b2745208c3ed21b151d297acdc84c3aca00b9e292cf28ec2d444771070157ea7760ed83c26f4fed387d0077e00b563a95825dac2cbc349fc0025ccf774e37b0a98ad9724d30e90f8c29b4091ccb738ed9ffc0573df776ee9ea30b3c038b55e532760ea4a8f152f2a52848020e5cee1cc537f2c2323399723081") + require.Equal(t, uint32(62), comp.EncodedAddressLen()) + hfPk, _ := hex.DecodeString(components.DummyPk) require.Equal(t, hfPk, comp.HardforkTriggerPubKey()) require.NotNil(t, comp.NodeTypeProvider()) require.NotNil(t, comp.WasmVMChangeLocker()) diff --git a/node/chainSimulator/components/dataComponents_test.go b/node/chainSimulator/components/dataComponents_test.go index a74f0b751f6..9bd27c36eba 100644 --- a/node/chainSimulator/components/dataComponents_test.go +++ b/node/chainSimulator/components/dataComponents_test.go @@ -3,12 +3,14 @@ package components import ( "testing" + "github.com/stretchr/testify/require" + retriever "github.com/multiversx/mx-chain-go/dataRetriever" chainStorage "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/require" ) func createArgsDataComponentsHolder() ArgsDataComponentsHolder { @@ -21,7 +23,7 @@ func createArgsDataComponentsHolder() ArgsDataComponentsHolder { }, DataPool: &dataRetriever.PoolsHolderStub{ MiniBlocksCalled: func() chainStorage.Cacher { - return &testscommon.CacherStub{} + return &cache.CacherStub{} }, }, InternalMarshaller: &testscommon.MarshallerStub{}, diff --git a/node/chainSimulator/components/instantBroadcastMessenger_test.go 
b/node/chainSimulator/components/instantBroadcastMessenger_test.go index 361caa03bbc..84770316337 100644 --- a/node/chainSimulator/components/instantBroadcastMessenger_test.go +++ b/node/chainSimulator/components/instantBroadcastMessenger_test.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus/mock" errorsMx "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/stretchr/testify/require" ) @@ -22,14 +24,14 @@ func TestNewInstantBroadcastMessenger(t *testing.T) { t.Run("nil shardCoordinator should error", func(t *testing.T) { t.Parallel() - mes, err := NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{}, nil) + mes, err := NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{}, nil) require.Equal(t, errorsMx.ErrNilShardCoordinator, err) require.Nil(t, mes) }) t.Run("should work", func(t *testing.T) { t.Parallel() - mes, err := NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{}, &mock.ShardCoordinatorMock{}) + mes, err := NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{}, &mock.ShardCoordinatorMock{}) require.NoError(t, err) require.NotNil(t, mes) }) @@ -41,7 +43,7 @@ func TestInstantBroadcastMessenger_IsInterfaceNil(t *testing.T) { var mes *instantBroadcastMessenger require.True(t, mes.IsInterfaceNil()) - mes, _ = NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{}, &mock.ShardCoordinatorMock{}) + mes, _ = NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{}, &mock.ShardCoordinatorMock{}) require.False(t, mes.IsInterfaceNil()) } @@ -60,7 +62,7 @@ func TestInstantBroadcastMessenger_BroadcastBlockDataLeader(t *testing.T) { "topic_0": {[]byte("txs topic 0")}, "topic_1": {[]byte("txs topic 1")}, } - mes, err := NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{ + mes, err := NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{ BroadcastMiniBlocksCalled: func(mbs map[uint32][]byte, bytes []byte) error { require.Equal(t, providedMBs, mbs) return expectedErr // for coverage only @@ -94,7 +96,7 @@ func TestInstantBroadcastMessenger_BroadcastBlockDataLeader(t *testing.T) { expectedTxs := map[string][][]byte{ "topic_0_META": {[]byte("txs topic meta")}, } - mes, err := NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{ + mes, err := NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{ BroadcastMiniBlocksCalled: func(mbs map[uint32][]byte, bytes []byte) error { require.Equal(t, expectedMBs, mbs) return nil @@ -114,7 +116,7 @@ func TestInstantBroadcastMessenger_BroadcastBlockDataLeader(t *testing.T) { t.Run("shard, empty miniblocks should early exit", func(t *testing.T) { t.Parallel() - mes, err := NewInstantBroadcastMessenger(&mock.BroadcastMessengerMock{ + mes, err := NewInstantBroadcastMessenger(&consensus.BroadcastMessengerMock{ BroadcastMiniBlocksCalled: func(mbs map[uint32][]byte, bytes []byte) error { require.Fail(t, "should have not been called") return nil diff --git a/node/chainSimulator/components/manualRoundHandler.go b/node/chainSimulator/components/manualRoundHandler.go index 479cf63a1f5..4f951930058 100644 --- a/node/chainSimulator/components/manualRoundHandler.go +++ b/node/chainSimulator/components/manualRoundHandler.go @@ -27,6 +27,11 @@ func (handler *manualRoundHandler) IncrementIndex() { atomic.AddInt64(&handler.index, 1) } +// RevertOneRound - +func (handler 
*manualRoundHandler) RevertOneRound() { + atomic.AddInt64(&handler.index, -1) +} + // Index returns the current index func (handler *manualRoundHandler) Index() int64 { return atomic.LoadInt64(&handler.index) diff --git a/node/chainSimulator/components/nodeFacade.go b/node/chainSimulator/components/nodeFacade.go index 139053cbf94..f8f5c30f9e9 100644 --- a/node/chainSimulator/components/nodeFacade.go +++ b/node/chainSimulator/components/nodeFacade.go @@ -94,7 +94,6 @@ func (node *testOnlyProcessingNode) createFacade(configs config.Configs, apiInte nodePack.WithNetworkComponents(node.NetworkComponentsHolder), nodePack.WithInitialNodesPubKeys(node.CoreComponentsHolder.GenesisNodesSetup().InitialNodesPubKeys()), nodePack.WithRoundDuration(node.CoreComponentsHolder.GenesisNodesSetup().GetRoundDuration()), - nodePack.WithConsensusGroupSize(int(node.CoreComponentsHolder.GenesisNodesSetup().GetShardConsensusGroupSize())), nodePack.WithGenesisTime(node.CoreComponentsHolder.GenesisTime()), nodePack.WithConsensusType(configs.GeneralConfig.Consensus.Type), nodePack.WithRequestedItemsHandler(node.ProcessComponentsHolder.RequestedItemsHandler()), diff --git a/node/chainSimulator/components/processComponents_test.go b/node/chainSimulator/components/processComponents_test.go index 536cb21abfc..98d33013fa3 100644 --- a/node/chainSimulator/components/processComponents_test.go +++ b/node/chainSimulator/components/processComponents_test.go @@ -5,12 +5,16 @@ import ( "sync" "testing" + "github.com/multiversx/mx-chain-core-go/core" coreData "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/endProcess" "github.com/multiversx/mx-chain-core-go/hashing/blake2b" "github.com/multiversx/mx-chain-core-go/hashing/keccak" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/require" + commonFactory "github.com/multiversx/mx-chain-go/common/factory" + "github.com/multiversx/mx-chain-go/common/graceperiod" disabledStatistics "github.com/multiversx/mx-chain-go/common/statistics/disabled" "github.com/multiversx/mx-chain-go/config" retriever "github.com/multiversx/mx-chain-go/dataRetriever" @@ -20,6 +24,7 @@ import ( chainStorage "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/components" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" @@ -37,7 +42,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storage" updateMocks "github.com/multiversx/mx-chain-go/update/mock" - "github.com/stretchr/testify/require" ) const testingProtocolSustainabilityAddress = "erd1932eft30w753xyvme8d49qejgkjc09n5e49w4mwdjtm0neld797su0dlxp" @@ -57,8 +61,10 @@ var ( ) func createArgsProcessComponentsHolder() ArgsProcessComponentsHolder { - nodesSetup, _ := sharding.NewNodesSetup("../../../integrationTests/factory/testdata/nodesSetup.json", 
addrPubKeyConv, valPubKeyConv, 3) - + var nodesConfig config.NodesConfig + _ = core.LoadJsonFile(&nodesConfig, "../../../integrationTests/factory/testdata/nodesSetup.json") + nodesSetup, _ := sharding.NewNodesSetup(nodesConfig, &chainParameters.ChainParametersHolderMock{}, addrPubKeyConv, valPubKeyConv, 3) + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) args := ArgsProcessComponentsHolder{ Config: testscommon.GetGeneralConfig(), EpochConfig: config.EpochConfig{ @@ -154,22 +160,23 @@ func createArgsProcessComponentsHolder() ArgsProcessComponentsHolder { return big.NewInt(0).Mul(big.NewInt(1000000000000000000), big.NewInt(20000000)) }, }, - Hash: blake2b.NewBlake2b(), - TxVersionCheckHandler: &testscommon.TxVersionCheckerStub{}, - RatingHandler: &testscommon.RaterMock{}, - EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, - EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, - EpochNotifierWithConfirm: &updateMocks.EpochStartNotifierStub{}, - RoundHandlerField: &testscommon.RoundHandlerMock{}, - RoundChangeNotifier: &epochNotifier.RoundNotifierStub{}, - ChanStopProcess: make(chan endProcess.ArgEndProcess, 1), - TxSignHasherField: keccak.NewKeccak(), - HardforkTriggerPubKeyField: []byte("hardfork pub key"), - WasmVMChangeLockerInternal: &sync.RWMutex{}, - NodeTypeProviderField: &nodeTypeProviderMock.NodeTypeProviderStub{}, - RatingsConfig: &testscommon.RatingsInfoMock{}, - PathHdl: &testscommon.PathManagerStub{}, - ProcessStatusHandlerInternal: &testscommon.ProcessStatusHandlerStub{}, + Hash: blake2b.NewBlake2b(), + TxVersionCheckHandler: &testscommon.TxVersionCheckerStub{}, + RatingHandler: &testscommon.RaterMock{}, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, + EpochNotifierWithConfirm: &updateMocks.EpochStartNotifierStub{}, + RoundHandlerField: &testscommon.RoundHandlerMock{}, + RoundChangeNotifier: &epochNotifier.RoundNotifierStub{}, + ChanStopProcess: make(chan endProcess.ArgEndProcess, 1), + TxSignHasherField: keccak.NewKeccak(), + HardforkTriggerPubKeyField: []byte("hardfork pub key"), + WasmVMChangeLockerInternal: &sync.RWMutex{}, + NodeTypeProviderField: &nodeTypeProviderMock.NodeTypeProviderStub{}, + RatingsConfig: &testscommon.RatingsInfoMock{}, + PathHdl: &testscommon.PathManagerStub{}, + ProcessStatusHandlerInternal: &testscommon.ProcessStatusHandlerStub{}, + EpochChangeGracePeriodHandlerField: gracePeriod, }, CryptoComponents: &mock.CryptoComponentsStub{ BlKeyGen: &cryptoMocks.KeyGenStub{}, diff --git a/node/chainSimulator/components/stateComponents.go b/node/chainSimulator/components/stateComponents.go index b3fddf55f40..998263a8d7a 100644 --- a/node/chainSimulator/components/stateComponents.go +++ b/node/chainSimulator/components/stateComponents.go @@ -29,6 +29,7 @@ type stateComponentsHolder struct { triesContainer common.TriesHolder triesStorageManager map[string]common.StorageManager missingTrieNodesNotifier common.MissingTrieNodesNotifier + trieLeavesRetriever common.TrieLeavesRetriever stateComponentsCloser io.Closer } @@ -70,6 +71,7 @@ func CreateStateComponents(args ArgsStateComponents) (*stateComponentsHolder, er triesContainer: stateComp.TriesContainer(), triesStorageManager: stateComp.TrieStorageManagers(), missingTrieNodesNotifier: stateComp.MissingTrieNodesNotifier(), + trieLeavesRetriever: stateComp.TrieLeavesRetriever(), 
stateComponentsCloser: stateComp, }, nil } @@ -109,6 +111,11 @@ func (s *stateComponentsHolder) MissingTrieNodesNotifier() common.MissingTrieNod return s.missingTrieNodesNotifier } +// TrieLeavesRetriever will return the trie leaves retriever +func (s *stateComponentsHolder) TrieLeavesRetriever() common.TrieLeavesRetriever { + return s.trieLeavesRetriever +} + // Close will close the state components func (s *stateComponentsHolder) Close() error { return s.stateComponentsCloser.Close() diff --git a/node/chainSimulator/components/statusCoreComponents_test.go b/node/chainSimulator/components/statusCoreComponents_test.go index c1a6a3336a7..3040a863ca9 100644 --- a/node/chainSimulator/components/statusCoreComponents_test.go +++ b/node/chainSimulator/components/statusCoreComponents_test.go @@ -46,6 +46,7 @@ func createArgs() (config.Configs, factory.CoreComponentsHolder) { IntMarsh: &testscommon.MarshallerStub{}, UInt64ByteSliceConv: &mockTests.Uint64ByteSliceConverterMock{}, NodesConfig: &genesisMocks.NodesSetupStub{}, + RatingsConfig: &testscommon.RatingsInfoMock{}, } } diff --git a/node/chainSimulator/components/storageService.go b/node/chainSimulator/components/storageService.go index 9a2a7c4860f..3a3f42f28b6 100644 --- a/node/chainSimulator/components/storageService.go +++ b/node/chainSimulator/components/storageService.go @@ -29,6 +29,7 @@ func CreateStore(numOfShards uint32) dataRetriever.StorageService { store.AddStorer(dataRetriever.EpochByHashUnit, CreateMemUnit()) store.AddStorer(dataRetriever.ResultsHashesByTxHashUnit, CreateMemUnit()) store.AddStorer(dataRetriever.TrieEpochRootHashUnit, CreateMemUnit()) + store.AddStorer(dataRetriever.ProofsUnit, CreateMemUnit()) for i := uint32(0); i < numOfShards; i++ { hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(i) diff --git a/node/chainSimulator/components/storageService_test.go b/node/chainSimulator/components/storageService_test.go index 3be398b53e6..c6df371fc1d 100644 --- a/node/chainSimulator/components/storageService_test.go +++ b/node/chainSimulator/components/storageService_test.go @@ -37,6 +37,7 @@ func TestCreateStore(t *testing.T) { dataRetriever.ResultsHashesByTxHashUnit, dataRetriever.TrieEpochRootHashUnit, dataRetriever.ShardHdrNonceHashDataUnit, + dataRetriever.ProofsUnit, dataRetriever.UnitType(101), // shard 2 } diff --git a/node/chainSimulator/components/syncedBroadcastNetwork_test.go b/node/chainSimulator/components/syncedBroadcastNetwork_test.go index 74e061a819a..535b3808203 100644 --- a/node/chainSimulator/components/syncedBroadcastNetwork_test.go +++ b/node/chainSimulator/components/syncedBroadcastNetwork_test.go @@ -6,8 +6,9 @@ import ( "github.com/multiversx/mx-chain-communication-go/p2p" "github.com/multiversx/mx-chain-core-go/core" - "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" ) func TestSyncedBroadcastNetwork_BroadcastShouldWorkOn3Peers(t *testing.T) { @@ -181,6 +182,7 @@ func TestSyncedBroadcastNetwork_SendDirectlyShouldNotDeadlock(t *testing.T) { topic := "topic" testMessage := []byte("test message") + msgID := []byte("msgID") peer1, err := NewSyncedMessenger(network) assert.Nil(t, err) @@ -191,9 +193,9 @@ func TestSyncedBroadcastNetwork_SendDirectlyShouldNotDeadlock(t *testing.T) { peer2, err := NewSyncedMessenger(network) assert.Nil(t, err) processor2 := 
&p2pmocks.MessageProcessorStub{ - ProcessReceivedMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error { + ProcessReceivedMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { log.Debug("sending message back to", "pid", fromConnectedPeer.Pretty()) - return source.SendToConnectedPeer(message.Topic(), []byte("reply: "+string(message.Data())), fromConnectedPeer) + return msgID, source.SendToConnectedPeer(message.Topic(), []byte("reply: "+string(message.Data())), fromConnectedPeer) }, } _ = peer2.CreateTopic(topic, true) @@ -285,7 +287,7 @@ func TestSyncedBroadcastNetwork_GetConnectedPeersOnTopic(t *testing.T) { func createMessageProcessor(t *testing.T, dataMap map[core.PeerID]map[string][]byte, pid core.PeerID) p2p.MessageProcessor { return &p2pmocks.MessageProcessorStub{ - ProcessReceivedMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error { + ProcessReceivedMessageCalled: func(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { m, found := dataMap[pid] if !found { m = make(map[string][]byte) @@ -297,7 +299,7 @@ func createMessageProcessor(t *testing.T, dataMap map[core.PeerID]map[string][]b assert.Equal(t, message.Peer(), fromConnectedPeer) m[message.Topic()] = message.Data() - return nil + return nil, nil }, } } diff --git a/node/chainSimulator/components/syncedMessenger.go b/node/chainSimulator/components/syncedMessenger.go index 09786c45842..70139f7c54b 100644 --- a/node/chainSimulator/components/syncedMessenger.go +++ b/node/chainSimulator/components/syncedMessenger.go @@ -85,7 +85,7 @@ func (messenger *syncedMessenger) receive(fromConnectedPeer core.PeerID, message for _, handler := range handlers { // this is needed to process all received messages on multiple go routines go func(proc p2p.MessageProcessor, p2pMessage p2p.MessageP2P, peer core.PeerID, localWG *sync.WaitGroup) { - err := proc.ProcessReceivedMessage(p2pMessage, peer, messenger) + _, err := proc.ProcessReceivedMessage(p2pMessage, peer, messenger) if err != nil { log.Trace("received message syncedMessenger", "error", err, "topic", p2pMessage.Topic(), "from connected peer", peer.Pretty()) } @@ -98,8 +98,8 @@ func (messenger *syncedMessenger) receive(fromConnectedPeer core.PeerID, message } // ProcessReceivedMessage does nothing and returns nil -func (messenger *syncedMessenger) ProcessReceivedMessage(_ p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) error { - return nil +func (messenger *syncedMessenger) ProcessReceivedMessage(_ p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) ([]byte, error) { + return nil, nil } // CreateTopic will create a topic for receiving data diff --git a/node/chainSimulator/components/syncedMessenger_test.go b/node/chainSimulator/components/syncedMessenger_test.go index c8c17918141..b8bb82f5342 100644 --- a/node/chainSimulator/components/syncedMessenger_test.go +++ b/node/chainSimulator/components/syncedMessenger_test.go @@ -4,8 +4,9 @@ import ( "fmt" "testing" - "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" ) func TestNewSyncedMessenger(t *testing.T) { @@ -58,7 +59,9 @@ func TestSyncedMessenger_DisabledMethodsShouldNotPanic(t *testing.T) { assert.Nil(t, messenger.SetPeerShardResolver(nil)) assert.Nil(t, 
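// Editor's sketch: p2p message processors now return a message ID alongside the error
// (callers in this diff discard it with `_, err := ...`). A minimal processor under the
// new contract could look like this; the type name noopProcessor is illustrative and
// IsInterfaceNil is assumed to still be required by the interface, as elsewhere in the
// codebase.
type noopProcessor struct{}

func (np *noopProcessor) ProcessReceivedMessage(msg p2p.MessageP2P, from core.PeerID, _ p2p.MessageHandler) ([]byte, error) {
	log.Trace("message received", "topic", msg.Topic(), "peer", from.Pretty())
	return nil, nil // nothing to report as a message ID, no error
}

func (np *noopProcessor) IsInterfaceNil() bool {
	return np == nil
}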
messenger.ConnectToPeer("")) assert.Nil(t, messenger.Bootstrap()) - assert.Nil(t, messenger.ProcessReceivedMessage(nil, "", nil)) + msgID, err := messenger.ProcessReceivedMessage(nil, "", nil) + assert.Nil(t, err) + assert.Nil(t, msgID) messenger.WaitForConnections(0, 0) diff --git a/node/chainSimulator/components/testOnlyProcessingNode.go b/node/chainSimulator/components/testOnlyProcessingNode.go index bc8f9b8de1a..8e6148f40f5 100644 --- a/node/chainSimulator/components/testOnlyProcessingNode.go +++ b/node/chainSimulator/components/testOnlyProcessingNode.go @@ -75,6 +75,8 @@ type testOnlyProcessingNode struct { httpServer shared.UpgradeableHttpServerHandler facadeHandler shared.FacadeHandler + + basePeers map[uint32]core.PeerID } // NewTestOnlyProcessingNode creates a new instance of a node that is able to only process transactions @@ -316,6 +318,7 @@ func (node *testOnlyProcessingNode) createNodesCoordinator(pref config.Preferenc node.CoreComponentsHolder.EnableEpochsHandler(), node.DataPool.CurrentEpochValidatorInfo(), node.BootstrapComponentsHolder.NodesCoordinatorRegistryFactory(), + node.CoreComponentsHolder.ChainParametersHandler(), ) if err != nil { return err @@ -341,6 +344,7 @@ func (node *testOnlyProcessingNode) createBroadcastMessenger() error { } node.broadcastMessenger, err = NewInstantBroadcastMessenger(broadcastMessenger, node.BootstrapComponentsHolder.ShardCoordinator()) + return err } @@ -394,6 +398,11 @@ func (node *testOnlyProcessingNode) GetStatusCoreComponents() factory.StatusCore return node.StatusCoreComponents } +// NetworkComponents will return the network components +func (node *testOnlyProcessingNode) GetNetworkComponents() factory.NetworkComponentsHolder { + return node.NetworkComponentsHolder +} + func (node *testOnlyProcessingNode) collectClosableComponents(apiInterface APIConfigurator) { node.closeHandler.AddComponent(node.ProcessComponentsHolder) node.closeHandler.AddComponent(node.DataComponentsHolder) @@ -609,6 +618,16 @@ func (node *testOnlyProcessingNode) getUserAccount(address []byte) (state.UserAc return userAccount, nil } +// GetBasePeers returns return network messenger ids for base nodes +func (node *testOnlyProcessingNode) GetBasePeers() map[uint32]core.PeerID { + return node.basePeers +} + +// SetBasePeers will set base network messenger id nodes per shard +func (node *testOnlyProcessingNode) SetBasePeers(basePeers map[uint32]core.PeerID) { + node.basePeers = basePeers +} + // Close will call the Close methods on all inner components func (node *testOnlyProcessingNode) Close() error { return node.closeHandler.Close() diff --git a/node/chainSimulator/configs/configs.go b/node/chainSimulator/configs/configs.go index 22fc863c7a0..e3ea5958dcc 100644 --- a/node/chainSimulator/configs/configs.go +++ b/node/chainSimulator/configs/configs.go @@ -52,6 +52,7 @@ type ArgsChainSimulatorConfigs struct { ConsensusGroupSize uint32 MetaChainMinNodes uint32 MetaChainConsensusGroupSize uint32 + Hysteresis float32 InitialEpoch uint32 RoundsPerEpoch core.OptionalUint64 NumNodesWaitingListShard uint32 @@ -115,7 +116,6 @@ func CreateChainSimulatorConfigs(args ArgsChainSimulatorConfigs) (*ArgsConfigsSi // set compatible trie configs configs.GeneralConfig.StateTriesConfig.SnapshotsEnabled = false - // enable db lookup extension configs.GeneralConfig.DbLookupExtensions.Enabled = true @@ -132,10 +132,13 @@ func CreateChainSimulatorConfigs(args ArgsChainSimulatorConfigs) (*ArgsConfigsSi return nil, err } + updateConfigsChainParameters(args, configs) 
node.ApplyArchCustomConfigs(configs) if args.AlterConfigsFunction != nil { args.AlterConfigsFunction(configs) + // this is needed to keep in sync Andromeda flag and the second entry from chain parameters + configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch[1].EnableEpoch = configs.EpochConfig.EnableEpochs.AndromedaEnableEpoch } return &ArgsConfigsSimulator{ @@ -146,6 +149,21 @@ func CreateChainSimulatorConfigs(args ArgsChainSimulatorConfigs) (*ArgsConfigsSi }, nil } +func updateConfigsChainParameters(args ArgsChainSimulatorConfigs, configs *config.Configs) { + for idx := 0; idx < len(configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch); idx++ { + configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch[idx].ShardMinNumNodes = args.MinNodesPerShard + configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch[idx].MetachainMinNumNodes = args.MetaChainMinNodes + configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch[idx].MetachainConsensusGroupSize = args.MetaChainConsensusGroupSize + configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch[idx].ShardConsensusGroupSize = args.ConsensusGroupSize + configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch[idx].RoundDuration = args.RoundDurationInMillis + configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch[idx].Hysteresis = args.Hysteresis + } + + if len(configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch) > 1 { + configs.GeneralConfig.GeneralSettings.ChainParametersByEpoch[1].ShardConsensusGroupSize = args.MinNodesPerShard + } +} + // SetMaxNumberOfNodesInConfigs will correctly set the max number of nodes in configs func SetMaxNumberOfNodesInConfigs(cfg *config.Configs, eligibleNodes uint32, waitingNodes uint32, numOfShards uint32) { cfg.SystemSCConfig.StakingSystemSCConfig.MaxNumberOfNodesForStake = uint64(eligibleNodes + waitingNodes) @@ -168,10 +186,10 @@ func SetMaxNumberOfNodesInConfigs(cfg *config.Configs, eligibleNodes uint32, wai // SetQuickJailRatingConfig will set the rating config in a way that leads to rapid jailing of a node func SetQuickJailRatingConfig(cfg *config.Configs) { - cfg.RatingsConfig.ShardChain.RatingSteps.ConsecutiveMissedBlocksPenalty = 100 - cfg.RatingsConfig.ShardChain.RatingSteps.HoursToMaxRatingFromStartRating = 1 - cfg.RatingsConfig.MetaChain.RatingSteps.ConsecutiveMissedBlocksPenalty = 100 - cfg.RatingsConfig.MetaChain.RatingSteps.HoursToMaxRatingFromStartRating = 1 + cfg.RatingsConfig.ShardChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty = 100 + cfg.RatingsConfig.ShardChain.RatingStepsByEpoch[0].HoursToMaxRatingFromStartRating = 1 + cfg.RatingsConfig.MetaChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty = 100 + cfg.RatingsConfig.MetaChain.RatingStepsByEpoch[0].HoursToMaxRatingFromStartRating = 1 } // SetStakingV4ActivationEpochs configures activation epochs for Staking V4. 
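For clarity, the per-epoch override introduced by updateConfigsChainParameters above, together with the Andromeda sync applied when an AlterConfigsFunction is supplied, can be summarized in a small self-contained sketch. The types and literal values below are simplified stand-ins, not the real config structs; only the field names and the override pattern are taken from this diff.

package main

import "fmt"

// chainParams is a stand-in for the real per-epoch chain parameters config entry.
type chainParams struct {
	EnableEpoch             uint32
	ShardConsensusGroupSize uint32
	ShardMinNumNodes        uint32
	Hysteresis              float32
}

func main() {
	byEpoch := []chainParams{{EnableEpoch: 0}, {EnableEpoch: 1}}
	minNodesPerShard, consensusGroupSize := uint32(3), uint32(1)
	hysteresis := float32(0.2)
	andromedaEnableEpoch := uint32(2) // illustrative value; normally read from EnableEpochs

	// every per-epoch entry is overwritten with the simulator arguments
	for i := range byEpoch {
		byEpoch[i].ShardMinNumNodes = minNodesPerShard
		byEpoch[i].ShardConsensusGroupSize = consensusGroupSize
		byEpoch[i].Hysteresis = hysteresis
	}
	if len(byEpoch) > 1 {
		// the second entry uses all shard nodes as the consensus group ...
		byEpoch[1].ShardConsensusGroupSize = minNodesPerShard
		// ... and, when an AlterConfigsFunction is provided, is kept in sync
		// with the Andromeda activation epoch
		byEpoch[1].EnableEpoch = andromedaEnableEpoch
	}
	fmt.Printf("%+v\n", byEpoch)
}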
@@ -280,14 +298,10 @@ func generateValidatorsKeyAndUpdateFiles( nodes.RoundDuration = args.RoundDurationInMillis nodes.StartTime = args.GenesisTimeStamp - nodes.ConsensusGroupSize = args.ConsensusGroupSize - nodes.MetaChainConsensusGroupSize = args.MetaChainConsensusGroupSize nodes.Hysteresis = 0 - nodes.MinNodesPerShard = args.MinNodesPerShard - nodes.MetaChainMinNodes = args.MetaChainMinNodes - nodes.InitialNodes = make([]*sharding.InitialNode, 0) + configs.NodesConfig.InitialNodes = make([]*config.InitialNodeConfig, 0) privateKeys := make([]crypto.PrivateKey, 0) publicKeys := make([]crypto.PublicKey, 0) walletIndex := 0 @@ -307,6 +321,12 @@ func generateValidatorsKeyAndUpdateFiles( Address: stakeWallets[walletIndex].Address.Bech32, }) + configs.NodesConfig.InitialNodes = append(configs.NodesConfig.InitialNodes, &config.InitialNodeConfig{ + PubKey: hex.EncodeToString(pkBytes), + Address: stakeWallets[walletIndex].Address.Bech32, + InitialRating: 5000001, + }) + walletIndex++ } @@ -326,6 +346,13 @@ func generateValidatorsKeyAndUpdateFiles( PubKey: hex.EncodeToString(pkBytes), Address: stakeWallets[walletIndex].Address.Bech32, }) + + configs.NodesConfig.InitialNodes = append(configs.NodesConfig.InitialNodes, &config.InitialNodeConfig{ + PubKey: hex.EncodeToString(pkBytes), + Address: stakeWallets[walletIndex].Address.Bech32, + InitialRating: 5000001, + }) + walletIndex++ } } diff --git a/node/chainSimulator/disabled/antiflooder.go b/node/chainSimulator/disabled/antiflooder.go index 0d4c45fd0e3..1e705b29c47 100644 --- a/node/chainSimulator/disabled/antiflooder.go +++ b/node/chainSimulator/disabled/antiflooder.go @@ -16,6 +16,10 @@ func NewAntiFlooder() *antiFlooder { return &antiFlooder{} } +// SetConsensusSizeNotifier does nothing +func (a *antiFlooder) SetConsensusSizeNotifier(_ process.ChainParametersSubscriber, _ uint32) { +} + // CanProcessMessage returns nil func (a *antiFlooder) CanProcessMessage(_ p2p.MessageP2P, _ core.PeerID) error { return nil diff --git a/node/chainSimulator/interface.go b/node/chainSimulator/interface.go index 0b2f51ca457..0c93ad65b1f 100644 --- a/node/chainSimulator/interface.go +++ b/node/chainSimulator/interface.go @@ -1,6 +1,11 @@ package chainSimulator -import "github.com/multiversx/mx-chain-go/node/chainSimulator/process" +import ( + "math/big" + + "github.com/multiversx/mx-chain-go/node/chainSimulator/dtos" + "github.com/multiversx/mx-chain-go/node/chainSimulator/process" +) // ChainHandler defines what a chain handler should be able to do type ChainHandler interface { @@ -13,5 +18,7 @@ type ChainHandler interface { type ChainSimulator interface { GenerateBlocks(numOfBlocks int) error GetNodeHandler(shardID uint32) process.NodeHandler + GenerateAddressInShard(providedShardID uint32) dtos.WalletAddress + GenerateAndMintWalletAddress(targetShardID uint32, value *big.Int) (dtos.WalletAddress, error) IsInterfaceNil() bool } diff --git a/node/chainSimulator/process/interface.go b/node/chainSimulator/process/interface.go index 7ae2f07517e..c47234e800f 100644 --- a/node/chainSimulator/process/interface.go +++ b/node/chainSimulator/process/interface.go @@ -1,6 +1,7 @@ package process import ( + "github.com/multiversx/mx-chain-core-go/core" chainData "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-go/api/shared" "github.com/multiversx/mx-chain-go/consensus" @@ -22,10 +23,13 @@ type NodeHandler interface { 
GetStateComponents() factory.StateComponentsHolder GetFacadeHandler() shared.FacadeHandler GetStatusCoreComponents() factory.StatusCoreComponentsHolder + GetNetworkComponents() factory.NetworkComponentsHolder SetKeyValueForAddress(addressBytes []byte, state map[string]string) error SetStateForAddress(address []byte, state *dtos.AddressState) error RemoveAccount(address []byte) error ForceChangeOfEpoch() error + GetBasePeers() map[uint32]core.PeerID + SetBasePeers(basePeers map[uint32]core.PeerID) Close() error IsInterfaceNil() bool } diff --git a/node/chainSimulator/process/processor.go b/node/chainSimulator/process/processor.go index 1c9819e27f0..50305440e76 100644 --- a/node/chainSimulator/process/processor.go +++ b/node/chainSimulator/process/processor.go @@ -3,12 +3,14 @@ package process import ( "time" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + dataBlock "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-go/common" - "github.com/multiversx/mx-chain-go/consensus/spos" heartbeatData "github.com/multiversx/mx-chain-go/heartbeat/data" "github.com/multiversx/mx-chain-go/node/chainSimulator/configs" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" logger "github.com/multiversx/mx-chain-logger-go" ) @@ -46,10 +48,13 @@ func (creator *blocksCreator) IncrementRound() { // CreateNewBlock creates and process a new block func (creator *blocksCreator) CreateNewBlock() error { - bp := creator.nodeHandler.GetProcessComponents().BlockProcessor() + processComponents := creator.nodeHandler.GetProcessComponents() + cryptoComponents := creator.nodeHandler.GetCryptoComponents() + coreComponents := creator.nodeHandler.GetCoreComponents() + bp := processComponents.BlockProcessor() - nonce, _, prevHash, prevRandSeed, epoch := creator.getPreviousHeaderData() - round := creator.nodeHandler.GetCoreComponents().RoundHandler().Index() + nonce, _, prevHash, prevRandSeed, epoch, prevHeader := creator.getPreviousHeaderData() + round := coreComponents.RoundHandler().Index() newHeader, err := bp.CreateNewHeader(uint64(round), nonce+1) if err != nil { return err @@ -71,38 +76,50 @@ func (creator *blocksCreator) CreateNewBlock() error { return err } - err = newHeader.SetPubKeysBitmap([]byte{1}) + err = newHeader.SetChainID([]byte(configs.ChainID)) if err != nil { return err } - err = newHeader.SetChainID([]byte(configs.ChainID)) + headerCreationTime := coreComponents.RoundHandler().TimeStamp() + err = newHeader.SetTimeStamp(uint64(headerCreationTime.Unix())) if err != nil { return err } - headerCreationTime := creator.nodeHandler.GetCoreComponents().RoundHandler().TimeStamp() - err = newHeader.SetTimeStamp(uint64(headerCreationTime.Unix())) + leader, validators, err := processComponents.NodesCoordinator().ComputeConsensusGroup(prevRandSeed, newHeader.GetRound(), shardID, epoch) if err != nil { return err } - validatorsGroup, err := creator.nodeHandler.GetProcessComponents().NodesCoordinator().ComputeConsensusGroup(prevRandSeed, newHeader.GetRound(), shardID, epoch) + pubKeyBitmap := GeneratePubKeyBitmap(len(validators)) + for idx, validator := range validators { + isManaged := cryptoComponents.KeysHandler().IsKeyManagedByCurrentNode(validator.PubKey()) + if isManaged { + continue + } + + err = 
UnsetBitInBitmap(idx, pubKeyBitmap) + if err != nil { + return err + } + } + + err = newHeader.SetPubKeysBitmap(pubKeyBitmap) if err != nil { return err } - blsKey := validatorsGroup[spos.IndexOfLeaderInConsensusGroup] - isManaged := creator.nodeHandler.GetCryptoComponents().KeysHandler().IsKeyManagedByCurrentNode(blsKey.PubKey()) + isManaged := cryptoComponents.KeysHandler().IsKeyManagedByCurrentNode(leader.PubKey()) if !isManaged { log.Debug("cannot propose block - leader bls key is missing", - "leader key", blsKey.PubKey(), + "leader key", leader.PubKey(), "shard", creator.nodeHandler.GetShardCoordinator().SelfId()) return nil } - signingHandler := creator.nodeHandler.GetCryptoComponents().ConsensusSigningHandler() - randSeed, err := signingHandler.CreateSignatureForPublicKey(newHeader.GetPrevRandSeed(), blsKey.PubKey()) + signingHandler := cryptoComponents.ConsensusSigningHandler() + randSeed, err := signingHandler.CreateSignatureForPublicKey(newHeader.GetPrevRandSeed(), leader.PubKey()) if err != nil { return err } @@ -111,6 +128,8 @@ func (creator *blocksCreator) CreateNewBlock() error { return err } + enableEpochHandler := coreComponents.EnableEpochsHandler() + header, block, err := bp.CreateBlock(newHeader, func() bool { return true }) @@ -118,7 +137,15 @@ func (creator *blocksCreator) CreateNewBlock() error { return err } - err = creator.setHeaderSignatures(header, blsKey.PubKey()) + prevHeaderStartOfEpoch := false + if prevHeader != nil { + prevHeaderStartOfEpoch = prevHeader.IsStartOfEpochBlock() + } + if prevHeaderStartOfEpoch { + creator.updatePeerShardMapper(header.GetEpoch()) + } + + headerProof, err := creator.ApplySignaturesAndGetProof(header, prevHeader, enableEpochHandler, validators, leader, pubKeyBitmap) if err != nil { return err } @@ -138,17 +165,131 @@ func (creator *blocksCreator) CreateNewBlock() error { return err } - err = creator.nodeHandler.GetBroadcastMessenger().BroadcastHeader(header, blsKey.PubKey()) + messenger := creator.nodeHandler.GetBroadcastMessenger() + err = messenger.BroadcastHeader(header, leader.PubKey()) if err != nil { return err } - err = creator.nodeHandler.GetBroadcastMessenger().BroadcastMiniBlocks(miniBlocks, blsKey.PubKey()) + if !check.IfNil(headerProof) { + err = messenger.BroadcastEquivalentProof(headerProof, leader.PubKey()) + if err != nil { + return err + } + } + + err = messenger.BroadcastMiniBlocks(miniBlocks, leader.PubKey()) if err != nil { return err } - return creator.nodeHandler.GetBroadcastMessenger().BroadcastTransactions(transactions, blsKey.PubKey()) + return messenger.BroadcastTransactions(transactions, leader.PubKey()) +} + +func (creator *blocksCreator) updatePeerShardMapper( + epoch uint32, +) { + peerShardMapper := creator.nodeHandler.GetProcessComponents().PeerShardMapper() + + nc := creator.nodeHandler.GetProcessComponents().NodesCoordinator() + + eligibleMaps, err := nc.GetAllEligibleValidatorsPublicKeys(epoch) + if err != nil { + log.Error("failed to get eligible validators map", "error", err) + return + } + + for shardID, eligibleMap := range eligibleMaps { + for _, pubKey := range eligibleMap { + peerID := creator.nodeHandler.GetBasePeers()[shardID] + + log.Debug("added custom peer mapping", "peerID", peerID.Pretty(), "shardID", shardID, "addrs", pubKey) + peerShardMapper.UpdatePeerIDInfo(peerID, pubKey, shardID) + } + } + +} + +// ApplySignaturesAndGetProof - +func (creator *blocksCreator) ApplySignaturesAndGetProof( + header data.HeaderHandler, + prevHeader data.HeaderHandler, + enableEpochHandler 
common.EnableEpochsHandler, + validators []nodesCoordinator.Validator, + leader nodesCoordinator.Validator, + pubKeyBitmap []byte, +) (*dataBlock.HeaderProof, error) { + nilPrevHeader := check.IfNil(prevHeader) + + err := creator.setHeaderSignatures(header, leader.PubKey(), validators) + if err != nil { + return nil, err + } + + coreComponents := creator.nodeHandler.GetCoreComponents() + hasher := coreComponents.Hasher() + marshaller := coreComponents.InternalMarshalizer() + headerHash, err := core.CalculateHash(marshaller, hasher, header) + if err != nil { + return nil, err + } + + pubKeys := extractValidatorPubKeys(validators) + newHeaderSig, err := creator.generateAggregatedSignature(headerHash, header.GetEpoch(), header.GetPubKeysBitmap(), pubKeys) + if err != nil { + return nil, err + } + + var headerProof *dataBlock.HeaderProof + shouldAddCurrentProof := !nilPrevHeader && enableEpochHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) + if shouldAddCurrentProof { + headerProof = createProofForHeader(pubKeyBitmap, newHeaderSig, headerHash, header) + creator.nodeHandler.GetDataComponents().Datapool().Headers().AddHeader(headerHash, header) + err = creator.nodeHandler.GetProcessComponents().HeaderSigVerifier().VerifyHeaderProof(headerProof) + if err != nil { + return nil, err + } + + dataPool := creator.nodeHandler.GetDataComponents().Datapool() + _ = dataPool.Proofs().AddProof(headerProof) + } + + return headerProof, nil +} + +func createProofForHeader(pubKeyBitmap, signature, headerHash []byte, header data.HeaderHandler) *dataBlock.HeaderProof { + return &dataBlock.HeaderProof{ + PubKeysBitmap: pubKeyBitmap, + AggregatedSignature: signature, + HeaderHash: headerHash, + HeaderEpoch: header.GetEpoch(), + HeaderNonce: header.GetNonce(), + HeaderShardId: header.GetShardID(), + HeaderRound: header.GetRound(), + IsStartOfEpoch: header.IsStartOfEpochBlock(), + } +} + +func (creator *blocksCreator) getPreviousHeaderData() (nonce, round uint64, prevHash, prevRandSeed []byte, epoch uint32, currentHeader data.HeaderHandler) { + chainHandler := creator.nodeHandler.GetChainHandler() + currentHeader = chainHandler.GetCurrentBlockHeader() + + if currentHeader != nil { + nonce, round = currentHeader.GetNonce(), currentHeader.GetRound() + prevHash = chainHandler.GetCurrentBlockHeaderHash() + prevRandSeed = currentHeader.GetRandSeed() + epoch = currentHeader.GetEpoch() + return + } + + roundHandler := creator.nodeHandler.GetCoreComponents().RoundHandler() + prevHash = chainHandler.GetGenesisHeaderHash() + prevRandSeed = chainHandler.GetGenesisHeader().GetRandSeed() + round = uint64(roundHandler.Index()) - 1 + epoch = chainHandler.GetGenesisHeader().GetEpoch() + nonce = chainHandler.GetGenesisHeader().GetNonce() + + return } func (creator *blocksCreator) setHeartBeat(header data.HeaderHandler) error { @@ -177,73 +318,80 @@ func (creator *blocksCreator) setHeartBeat(header data.HeaderHandler) error { return nil } -func (creator *blocksCreator) getPreviousHeaderData() (nonce, round uint64, prevHash, prevRandSeed []byte, epoch uint32) { - currentHeader := creator.nodeHandler.GetChainHandler().GetCurrentBlockHeader() - - if currentHeader != nil { - nonce, round = currentHeader.GetNonce(), currentHeader.GetRound() - prevHash = creator.nodeHandler.GetChainHandler().GetCurrentBlockHeaderHash() - prevRandSeed = currentHeader.GetRandSeed() - epoch = currentHeader.GetEpoch() - return - } - - prevHash = creator.nodeHandler.GetChainHandler().GetGenesisHeaderHash() - prevRandSeed = 
creator.nodeHandler.GetChainHandler().GetGenesisHeader().GetRandSeed() - round = uint64(creator.nodeHandler.GetCoreComponents().RoundHandler().Index()) - 1 - epoch = creator.nodeHandler.GetChainHandler().GetGenesisHeader().GetEpoch() - nonce = creator.nodeHandler.GetChainHandler().GetGenesisHeader().GetNonce() - - return -} - -func (creator *blocksCreator) setHeaderSignatures(header data.HeaderHandler, blsKeyBytes []byte) error { - signingHandler := creator.nodeHandler.GetCryptoComponents().ConsensusSigningHandler() +func (creator *blocksCreator) setHeaderSignatures( + header data.HeaderHandler, + blsKeyBytes []byte, + validators []nodesCoordinator.Validator, +) error { headerClone := header.ShallowClone() _ = headerClone.SetPubKeysBitmap(nil) - marshalizedHdr, err := creator.nodeHandler.GetCoreComponents().InternalMarshalizer().Marshal(headerClone) + marshalizedHdr, err := creator.nodeHandler.GetCoreComponents(). + InternalMarshalizer().Marshal(headerClone) if err != nil { return err } - err = signingHandler.Reset([]string{string(blsKeyBytes)}) + headerHash := creator.nodeHandler.GetCoreComponents().Hasher().Compute(string(marshalizedHdr)) + pubKeys := extractValidatorPubKeys(validators) + + sig, err := creator.generateAggregatedSignature(headerHash, header.GetEpoch(), header.GetPubKeysBitmap(), pubKeys) if err != nil { return err } - headerHash := creator.nodeHandler.GetCoreComponents().Hasher().Compute(string(marshalizedHdr)) - _, err = signingHandler.CreateSignatureShareForPublicKey( - headerHash, - uint16(0), - header.GetEpoch(), - blsKeyBytes, - ) - if err != nil { - return err + isEquivalentMessageEnabled := creator.nodeHandler.GetCoreComponents().EnableEpochsHandler().IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) + if !isEquivalentMessageEnabled { + if err = header.SetSignature(sig); err != nil { + return err + } } - sig, err := signingHandler.AggregateSigs(header.GetPubKeysBitmap(), header.GetEpoch()) + leaderSignature, err := creator.createLeaderSignature(header, blsKeyBytes) if err != nil { return err } - err = header.SetSignature(sig) + return header.SetLeaderSignature(leaderSignature) +} + +func (creator *blocksCreator) generateAggregatedSignature(headerHash []byte, epoch uint32, pubKeysBitmap []byte, pubKeys []string) ([]byte, error) { + signingHandler := creator.nodeHandler.GetCryptoComponents().ConsensusSigningHandler() + + err := signingHandler.Reset(pubKeys) if err != nil { - return err + return nil, err } - leaderSignature, err := creator.createLeaderSignature(header, blsKeyBytes) - if err != nil { - return err + totalKey := 0 + for idx, pubKey := range pubKeys { + isManaged := creator.nodeHandler.GetCryptoComponents().KeysHandler().IsKeyManagedByCurrentNode([]byte(pubKey)) + if !isManaged { + + continue + } + + totalKey++ + if _, err = signingHandler.CreateSignatureShareForPublicKey(headerHash, uint16(idx), epoch, []byte(pubKey)); err != nil { + return nil, err + } } - err = header.SetLeaderSignature(leaderSignature) + aggSig, err := signingHandler.AggregateSigs(pubKeysBitmap, epoch) if err != nil { - return err + log.Warn("total", "total", totalKey, "err", err) + return nil, err } - return nil + return aggSig, nil +} + +func extractValidatorPubKeys(validators []nodesCoordinator.Validator) []string { + pubKeys := make([]string, len(validators)) + for i, validator := range validators { + pubKeys[i] = string(validator.PubKey()) + } + return pubKeys } func (creator *blocksCreator) createLeaderSignature(header data.HeaderHandler, blsKeyBytes []byte) ([]byte, 
error) { @@ -267,3 +415,36 @@ func (creator *blocksCreator) createLeaderSignature(header data.HeaderHandler, b func (creator *blocksCreator) IsInterfaceNil() bool { return creator == nil } + +// GeneratePubKeyBitmap generates a []byte where the first `numOfOnes` bits are set to 1. +func GeneratePubKeyBitmap(numOfOnes int) []byte { + if numOfOnes <= 0 { + return nil // Handle invalid cases + } + + // calculate how many full bytes are needed + numBytes := (numOfOnes + 7) / 8 // Equivalent to ceil(numOfOnes / 8) + result := make([]byte, numBytes) + + // fill in the bytes + for i := 0; i < numBytes; i++ { + bitsLeft := numOfOnes - (i * 8) + if bitsLeft >= 8 { + result[i] = 0xFF // All 8 bits set to 1 (255 in decimal) + } else { + result[i] = byte((1 << bitsLeft) - 1) // Only set the needed bits + } + } + + return result +} + +// UnsetBitInBitmap will unset a bit from provided bit based on the provided index +func UnsetBitInBitmap(index int, bitmap []byte) error { + if index/8 >= len(bitmap) { + return common.ErrWrongSizeBitmap + } + bitmap[index/8] = bitmap[index/8] & ^(1 << uint8(index%8)) + + return nil +} diff --git a/node/chainSimulator/process/processor_test.go b/node/chainSimulator/process/processor_test.go index 84a93eea028..7dbd8464d1d 100644 --- a/node/chainSimulator/process/processor_test.go +++ b/node/chainSimulator/process/processor_test.go @@ -9,9 +9,10 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" - mockConsensus "github.com/multiversx/mx-chain-go/consensus/mock" "github.com/multiversx/mx-chain-go/factory" "github.com/multiversx/mx-chain-go/integrationTests/mock" "github.com/multiversx/mx-chain-go/node/chainSimulator/components/heartbeat" @@ -22,10 +23,10 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/chainSimulator" testsConsensus "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" testsFactory "github.com/multiversx/mx-chain-go/testscommon/factory" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/require" ) var expectedErr = errors.New("expected error") @@ -222,8 +223,8 @@ func TestBlocksCreator_CreateNewBlock(t *testing.T) { }, }, NodesCoord: &shardingMocks.NodesCoordinatorStub{ - ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return nil, expectedErr + ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return nil, nil, expectedErr }, }, } @@ -516,7 +517,7 @@ func TestBlocksCreator_CreateNewBlock(t *testing.T) { nodeHandler := getNodeHandler() nodeHandler.GetBroadcastMessengerCalled = func() consensus.BroadcastMessenger { - return 
&mockConsensus.BroadcastMessengerMock{ + return &testsConsensus.BroadcastMessengerMock{ BroadcastHeaderCalled: func(handler data.HeaderHandler, bytes []byte) error { return expectedErr }, @@ -576,6 +577,9 @@ func getNodeHandler() *chainSimulator.NodeHandlerMock { }, } }, + EnableEpochsHandlerCalled: func() common.EnableEpochsHandler { + return &enableEpochsHandlerMock.EnableEpochsHandlerStub{} + }, } }, GetProcessComponentsCalled: func() factory.ProcessComponentsHolder { @@ -597,10 +601,9 @@ func getNodeHandler() *chainSimulator.NodeHandlerMock { }, }, NodesCoord: &shardingMocks.NodesCoordinatorStub{ - ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{ - shardingMocks.NewValidatorMock([]byte("A"), 1, 1), - }, nil + ComputeConsensusGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + v := shardingMocks.NewValidatorMock([]byte("A"), 1, 1) + return v, []nodesCoordinator.Validator{v}, nil }, }, } @@ -626,7 +629,30 @@ func getNodeHandler() *chainSimulator.NodeHandlerMock { } }, GetBroadcastMessengerCalled: func() consensus.BroadcastMessenger { - return &mockConsensus.BroadcastMessengerMock{} + return &testsConsensus.BroadcastMessengerMock{} }, } } + +func TestGeneratePubKeyBitmap(t *testing.T) { + t.Parallel() + + require.Equal(t, []byte{1}, chainSimulatorProcess.GeneratePubKeyBitmap(1)) + require.Equal(t, []byte{3}, chainSimulatorProcess.GeneratePubKeyBitmap(2)) + require.Equal(t, []byte{7}, chainSimulatorProcess.GeneratePubKeyBitmap(3)) + require.Equal(t, []byte{255, 255, 15}, chainSimulatorProcess.GeneratePubKeyBitmap(20)) + + bitmap := chainSimulatorProcess.GeneratePubKeyBitmap(2) + _ = chainSimulatorProcess.UnsetBitInBitmap(0, bitmap) + require.Equal(t, []byte{2}, bitmap) + + bitmap = chainSimulatorProcess.GeneratePubKeyBitmap(20) + _ = chainSimulatorProcess.UnsetBitInBitmap(3, bitmap) + require.Equal(t, []byte{247, 255, 15}, bitmap) + + err := chainSimulatorProcess.UnsetBitInBitmap(3, nil) + require.Equal(t, common.ErrWrongSizeBitmap, err) + + err = chainSimulatorProcess.UnsetBitInBitmap(3, []byte{}) + require.Equal(t, common.ErrWrongSizeBitmap, err) +} diff --git a/node/external/blockAPI/apiBlockFactory_test.go b/node/external/blockAPI/apiBlockFactory_test.go index 679aa6c0e1a..43e41a65173 100644 --- a/node/external/blockAPI/apiBlockFactory_test.go +++ b/node/external/blockAPI/apiBlockFactory_test.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-go/process/txstatus" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + dataRetrieverTestCommon "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" @@ -25,6 +26,10 @@ import ( func createMockArgsAPIBlockProc() *ArgAPIBlockProcessor { statusComputer, _ := txstatus.NewStatusComputer(0, mock.NewNonceHashConverterMock(), &storageMocks.ChainStorerStub{}) + chainHandler := &testscommon.ChainHandlerMock{} + _ = chainHandler.SetCurrentBlockHeaderAndRootHash(&block.Header{ + Nonce: 123456, + }, []byte("root")) return 
&ArgAPIBlockProcessor{ Store: &storageMocks.ChainStorerStub{}, @@ -41,6 +46,8 @@ func createMockArgsAPIBlockProc() *ArgAPIBlockProcessor { AccountsRepository: &state.AccountsRepositoryStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ProofsPool: &dataRetrieverTestCommon.ProofsPoolMock{}, + BlockChain: chainHandler, } } @@ -183,6 +190,25 @@ func TestCreateAPIBlockProcessorNilArgs(t *testing.T) { _, err := CreateAPIBlockProcessor(arguments) assert.Equal(t, errNilEnableEpochsHandler, err) }) + t.Run("NilProofsPool", func(t *testing.T) { + t.Parallel() + + arguments := createMockArgsAPIBlockProc() + arguments.ProofsPool = nil + + _, err := CreateAPIBlockProcessor(arguments) + assert.Equal(t, process.ErrNilProofsPool, err) + }) + + t.Run("NilBlockChain", func(t *testing.T) { + t.Parallel() + + arguments := createMockArgsAPIBlockProc() + arguments.BlockChain = nil + + _, err := CreateAPIBlockProcessor(arguments) + assert.Equal(t, process.ErrNilBlockChain, err) + }) } func TestGetBlockByHash_KeyNotFound(t *testing.T) { diff --git a/node/external/blockAPI/baseBlock.go b/node/external/blockAPI/baseBlock.go index df637f338d6..d4ce11d0a92 100644 --- a/node/external/blockAPI/baseBlock.go +++ b/node/external/blockAPI/baseBlock.go @@ -20,6 +20,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/api/shared/logging" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -28,7 +30,6 @@ import ( "github.com/multiversx/mx-chain-go/outport/process/alteredaccounts/shared" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" - logger "github.com/multiversx/mx-chain-logger-go" ) // BlockStatus is the status of a block @@ -59,6 +60,8 @@ type baseAPIBlockProcessor struct { accountsRepository state.AccountsRepository scheduledTxsExecutionHandler process.ScheduledTxsExecutionHandler enableEpochsHandler common.EnableEpochsHandler + proofsPool dataRetriever.ProofsPool + blockchain data.ChainHandler } var log = logger.GetOrCreate("node/blockAPI") @@ -601,3 +604,67 @@ func createAlteredBlockHash(hash []byte) []byte { return alteredHash } + +func (bap *baseAPIBlockProcessor) addProof( + headerHash []byte, + header data.HeaderHandler, + apiBlock *api.Block, +) error { + if !bap.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + return nil + } + + headerProof, err := bap.getHeaderProof(headerHash, header) + if err != nil { + return errCannotFindBlockProof + } + + apiBlock.PubKeyBitmap = hex.EncodeToString(headerProof.GetPubKeysBitmap()) + apiBlock.Signature = hex.EncodeToString(headerProof.GetAggregatedSignature()) + + apiBlock.Proof = proofToAPIProof(headerProof) + + return nil +} + +func (bap *baseAPIBlockProcessor) getHeaderProof( + headerHash []byte, + header data.HeaderHandler, +) (data.HeaderProofHandler, error) { + proofFromPool, err := bap.proofsPool.GetProof(header.GetShardID(), headerHash) + if err == nil { + return proofFromPool, nil + } + + proofsStorer, err := bap.store.GetStorer(dataRetriever.ProofsUnit) + if 
err != nil { + return nil, err + } + + proofBytes, err := proofsStorer.GetFromEpoch(headerHash, header.GetEpoch()) + if err != nil { + return nil, err + } + + proof := &block.HeaderProof{} + err = bap.marshalizer.Unmarshal(proof, proofBytes) + + return proof, err +} + +func (bap *baseAPIBlockProcessor) isBlockNonceInStorage(blockNonce uint64) bool { + return blockNonce <= bap.blockchain.GetCurrentBlockHeader().GetNonce() +} + +func proofToAPIProof(proof data.HeaderProofHandler) *api.HeaderProof { + return &api.HeaderProof{ + PubKeysBitmap: hex.EncodeToString(proof.GetPubKeysBitmap()), + AggregatedSignature: hex.EncodeToString(proof.GetAggregatedSignature()), + HeaderHash: hex.EncodeToString(proof.GetHeaderHash()), + HeaderEpoch: proof.GetHeaderEpoch(), + HeaderNonce: proof.GetHeaderNonce(), + HeaderShardId: proof.GetHeaderShardId(), + HeaderRound: proof.GetHeaderRound(), + IsStartOfEpoch: proof.GetIsStartOfEpoch(), + } +} diff --git a/node/external/blockAPI/baseBlock_test.go b/node/external/blockAPI/baseBlock_test.go index 9518883166b..8f0b2b9ed63 100644 --- a/node/external/blockAPI/baseBlock_test.go +++ b/node/external/blockAPI/baseBlock_test.go @@ -4,9 +4,11 @@ import ( "bytes" "encoding/hex" "errors" + "math/big" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/api" "github.com/multiversx/mx-chain-core-go/data/block" @@ -20,7 +22,9 @@ import ( "github.com/multiversx/mx-chain-go/node/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + dataRetrieverTestsCommon "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" storageMocks "github.com/multiversx/mx-chain-go/testscommon/storage" @@ -42,6 +46,7 @@ func createBaseBlockProcessor() *baseAPIBlockProcessor { apiTransactionHandler: &mock.TransactionAPIHandlerStub{}, logsFacade: &testscommon.LogsFacadeStub{}, receiptsRepository: &testscommon.ReceiptsRepositoryStub{}, + enableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } } @@ -410,6 +415,166 @@ func TestAddScheduledInfoInBlock(t *testing.T) { }, apiBlock) } +func TestProofToAPIProof(t *testing.T) { + t.Parallel() + + headerProof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderEpoch: 1, + HeaderNonce: 3, + HeaderShardId: 2, + HeaderRound: 4, + IsStartOfEpoch: true, + } + + proofToAPIProof(headerProof) + require.Equal(t, &api.HeaderProof{ + PubKeysBitmap: hex.EncodeToString(headerProof.PubKeysBitmap), + AggregatedSignature: hex.EncodeToString(headerProof.AggregatedSignature), + HeaderHash: hex.EncodeToString(headerProof.HeaderHash), + HeaderEpoch: 1, + HeaderNonce: 3, + HeaderShardId: 2, + HeaderRound: 4, + IsStartOfEpoch: true, + }, proofToAPIProof(headerProof)) +} + +func TestAddProof(t *testing.T) { + t.Parallel() + + t.Run("no proof for required block should error", func(t *testing.T) { + t.Parallel() + + baseAPIBlockProc := createBaseBlockProcessor() + 
baseAPIBlockProc.proofsPool = &dataRetrieverTestsCommon.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + return nil, errors.New("error") + }, + } + baseAPIBlockProc.store = &storageMocks.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return nil, errors.New("error") + }, + } + baseAPIBlockProc.enableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + + header := &block.HeaderV2{} + + err := baseAPIBlockProc.addProof([]byte("hash"), header, &api.Block{}) + require.Equal(t, errCannotFindBlockProof, err) + }) + + t.Run("proof for current block returned from pool", func(t *testing.T) { + t.Parallel() + + baseAPIBlockProc := createBaseBlockProcessor() + baseAPIBlockProc.proofsPool = &dataRetrieverTestsCommon.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + return &block.HeaderProof{ + HeaderHash: []byte("hash2"), + }, nil + }, + } + baseAPIBlockProc.enableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + + header := &block.HeaderV2{} + + apiBlock := &api.Block{} + err := baseAPIBlockProc.addProof([]byte("hash"), header, apiBlock) + require.Nil(t, err) + + require.Equal(t, &api.HeaderProof{ + HeaderHash: hex.EncodeToString([]byte("hash2")), + }, apiBlock.Proof) + }) + + t.Run("no previous proof only current proof", func(t *testing.T) { + t.Parallel() + + baseAPIBlockProc := createBaseBlockProcessor() + baseAPIBlockProc.proofsPool = &dataRetrieverTestsCommon.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + return &block.HeaderProof{ + HeaderHash: []byte("hash2"), + }, nil + }, + } + baseAPIBlockProc.enableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + + baseAPIBlockProc.enableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + header := &block.HeaderV2{} + + apiBlock := &api.Block{} + err := baseAPIBlockProc.addProof([]byte("hash"), header, apiBlock) + require.Nil(t, err) + + require.Equal(t, &api.HeaderProof{ + HeaderHash: hex.EncodeToString([]byte("hash2")), + }, apiBlock.Proof) + }) + + t.Run("proof for block returned from storage", func(t *testing.T) { + t.Parallel() + + baseAPIBlockProc := createBaseBlockProcessor() + baseAPIBlockProc.proofsPool = &dataRetrieverTestsCommon.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + return nil, errors.New("error") + }, + } + baseAPIBlockProc.enableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + + proof := &block.HeaderProof{ + HeaderHash: []byte("hash2"), + } + proofBytes, err := baseAPIBlockProc.marshalizer.Marshal(proof) + require.Nil(t, err) + + baseAPIBlockProc.store = &storageMocks.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &storageMocks.StorerStub{ + GetFromEpochCalled: 
func(key []byte, epoch uint32) ([]byte, error) { + return proofBytes, nil + }, + }, nil + }, + } + + header := &block.HeaderV2{} + + apiBlock := &api.Block{} + err = baseAPIBlockProc.addProof([]byte("hash"), header, apiBlock) + require.Nil(t, err) + + require.Equal(t, &api.HeaderProof{ + HeaderHash: hex.EncodeToString([]byte("hash2")), + }, apiBlock.Proof) + }) +} + func TestBigInToString(t *testing.T) { t.Parallel() diff --git a/node/external/blockAPI/blockArgs.go b/node/external/blockAPI/blockArgs.go index cd0836d6546..4646ed1f2f8 100644 --- a/node/external/blockAPI/blockArgs.go +++ b/node/external/blockAPI/blockArgs.go @@ -2,6 +2,7 @@ package blockAPI import ( "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" @@ -31,4 +32,6 @@ type ArgAPIBlockProcessor struct { AccountsRepository state.AccountsRepository ScheduledTxsExecutionHandler process.ScheduledTxsExecutionHandler EnableEpochsHandler common.EnableEpochsHandler + ProofsPool dataRetriever.ProofsPool + BlockChain data.ChainHandler } diff --git a/node/external/blockAPI/check.go b/node/external/blockAPI/check.go index b17ddedf22b..c1e9e404a56 100644 --- a/node/external/blockAPI/check.go +++ b/node/external/blockAPI/check.go @@ -63,6 +63,13 @@ func checkNilArg(arg *ArgAPIBlockProcessor) error { if check.IfNil(arg.EnableEpochsHandler) { return errNilEnableEpochsHandler } + if check.IfNil(arg.ProofsPool) { + return process.ErrNilProofsPool + } + if check.IfNil(arg.BlockChain) { + return process.ErrNilBlockChain + } + return core.CheckHandlerCompatibility(arg.EnableEpochsHandler, []core.EnableEpochFlag{ common.RefactorPeersMiniBlocksFlag, }) diff --git a/node/external/blockAPI/errors.go b/node/external/blockAPI/errors.go index 38123355dfb..18c413917c3 100644 --- a/node/external/blockAPI/errors.go +++ b/node/external/blockAPI/errors.go @@ -18,3 +18,5 @@ var errCannotUnmarshalTransactions = errors.New("cannot unmarshal transaction(s) var errCannotLoadReceipts = errors.New("cannot load receipt(s)") var errCannotUnmarshalReceipts = errors.New("cannot unmarshal receipt(s)") var errUnknownBlockRequestType = errors.New("unknown block request type") +var errCannotFindBlockProof = errors.New("cannot find block proof") +var errBlockNotFound = errors.New("block not found") diff --git a/node/external/blockAPI/internalBlock.go b/node/external/blockAPI/internalBlock.go index 7ee37bede33..7e737448b9d 100644 --- a/node/external/blockAPI/internalBlock.go +++ b/node/external/blockAPI/internalBlock.go @@ -42,7 +42,7 @@ func (ibp *internalBlockProcessor) GetInternalShardBlockByNonce(format common.Ap return nil, ErrShardOnlyEndpoint } - storerUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(ibp.selfShardID) + storerUnit := dataRetriever.GetHdrNonceHashDataUnit(ibp.selfShardID) nonceToByteSlice := ibp.uint64ByteSliceConverter.ToByteSlice(nonce) headerHash, err := ibp.store.Get(storerUnit, nonceToByteSlice) diff --git a/node/external/blockAPI/metaBlock.go b/node/external/blockAPI/metaBlock.go index 820ebb4ad3c..582acf4613d 100644 --- a/node/external/blockAPI/metaBlock.go +++ b/node/external/blockAPI/metaBlock.go @@ -39,18 +39,33 @@ func newMetaApiBlockProcessor(arg *ArgAPIBlockProcessor, emptyReceiptsHash []byt accountsRepository: 
arg.AccountsRepository, scheduledTxsExecutionHandler: arg.ScheduledTxsExecutionHandler, enableEpochsHandler: arg.EnableEpochsHandler, + proofsPool: arg.ProofsPool, + blockchain: arg.BlockChain, }, } } // GetBlockByNonce wil return a meta APIBlock by nonce func (mbp *metaAPIBlockProcessor) GetBlockByNonce(nonce uint64, options api.BlockQueryOptions) (*api.Block, error) { + if !mbp.isBlockNonceInStorage(nonce) { + return nil, errBlockNotFound + } + + headerHash, blockBytes, err := mbp.getBlockHashAndBytesByNonce(nonce) + if err != nil { + return nil, err + } + + return mbp.convertMetaBlockBytesToAPIBlock(headerHash, blockBytes, options) +} + +func (mbp *metaAPIBlockProcessor) getBlockHashAndBytesByNonce(nonce uint64) ([]byte, []byte, error) { storerUnit := dataRetriever.MetaHdrNonceHashDataUnit nonceToByteSlice := mbp.uint64ByteSliceConverter.ToByteSlice(nonce) headerHash, err := mbp.store.Get(storerUnit, nonceToByteSlice) if err != nil { - return nil, err + return nil, nil, err } // if genesis block, get the nonce key corresponding to the altered block @@ -60,15 +75,15 @@ func (mbp *metaAPIBlockProcessor) GetBlockByNonce(nonce uint64, options api.Bloc alteredHeaderHash, err := mbp.store.Get(storerUnit, nonceToByteSlice) if err != nil { - return nil, err + return nil, nil, err } blockBytes, err := mbp.getFromStorer(dataRetriever.MetaBlockUnit, alteredHeaderHash) if err != nil { - return nil, err + return nil, nil, err } - return mbp.convertMetaBlockBytesToAPIBlock(headerHash, blockBytes, options) + return headerHash, blockBytes, nil } // GetBlockByHash will return a meta APIBlock by hash @@ -242,6 +257,11 @@ func (mbp *metaAPIBlockProcessor) convertMetaBlockBytesToAPIBlock(hash []byte, b addScheduledInfoInBlock(blockHeader, apiMetaBlock) addStartOfEpochInfoInBlock(blockHeader, apiMetaBlock) + err = mbp.addProof(hash, blockHeader, apiMetaBlock) + if err != nil { + return nil, err + } + return apiMetaBlock, nil } diff --git a/node/external/blockAPI/metaBlock_test.go b/node/external/blockAPI/metaBlock_test.go index 3d73d2c2daa..343abd17ba8 100644 --- a/node/external/blockAPI/metaBlock_test.go +++ b/node/external/blockAPI/metaBlock_test.go @@ -19,7 +19,9 @@ import ( "github.com/multiversx/mx-chain-go/outport/process/alteredaccounts/shared" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + dataRetrieverTestsCommon "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/state" @@ -34,6 +36,11 @@ func createMockMetaAPIProcessor( withHistory bool, withKey bool, ) *metaAPIBlockProcessor { + chainHandler := &testscommon.ChainHandlerMock{} + _ = chainHandler.SetCurrentBlockHeaderAndRootHash(&block.Header{ + Nonce: 123456, + }, []byte("root")) + return newMetaApiBlockProcessor(&ArgAPIBlockProcessor{ APITransactionHandler: &mock.TransactionAPIHandlerStub{}, SelfShardID: core.MetachainShardId, @@ -63,6 +70,9 @@ func createMockMetaAPIProcessor( AlteredAccountsProvider: &testscommon.AlteredAccountsProviderStub{}, AccountsRepository: &state.AccountsRepositoryStub{}, ScheduledTxsExecutionHandler: 
&testscommon.ScheduledTxsExecutionStub{}, + ProofsPool: &dataRetrieverTestsCommon.ProofsPoolMock{}, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStubWithNoFlagsDefined(), + BlockChain: chainHandler, }, nil) } @@ -85,6 +95,29 @@ func TestMetaAPIBlockProcessor_GetBlockByHashInvalidHashShouldErr(t *testing.T) assert.Error(t, err) } +func TestMetaAPIBlockProcessor_BlockByNonceNonceTooHighShouldErr(t *testing.T) { + t.Parallel() + + epoch := uint32(0) + headerHash := []byte("d08089f2ab739520598fd7aeed08c427460fe94f286383047f3f61951afc4e00") + + storerMock := genericMocks.NewStorerMockWithEpoch(epoch) + + blockProc := createMockMetaAPIProcessor( + headerHash, + storerMock, + true, + true, + ) + blockProc.blockchain = &testscommon.ChainHandlerMock{} + err := blockProc.blockchain.SetCurrentBlockHeaderAndRootHash(&block.MetaBlock{Nonce: 10}, []byte("root")) + require.NoError(t, err) + + res, err := blockProc.GetBlockByNonce(11, api.BlockQueryOptions{}) + require.Nil(t, res) + require.Equal(t, errBlockNotFound, err) +} + func TestMetaAPIBlockProcessor_GetBlockByNonceInvalidNonceShouldErr(t *testing.T) { t.Parallel() diff --git a/node/external/blockAPI/shardBlock.go b/node/external/blockAPI/shardBlock.go index 1417336658f..9487f0498a0 100644 --- a/node/external/blockAPI/shardBlock.go +++ b/node/external/blockAPI/shardBlock.go @@ -40,18 +40,33 @@ func newShardApiBlockProcessor(arg *ArgAPIBlockProcessor, emptyReceiptsHash []by accountsRepository: arg.AccountsRepository, scheduledTxsExecutionHandler: arg.ScheduledTxsExecutionHandler, enableEpochsHandler: arg.EnableEpochsHandler, + proofsPool: arg.ProofsPool, + blockchain: arg.BlockChain, }, } } // GetBlockByNonce will return a shard APIBlock by nonce func (sbp *shardAPIBlockProcessor) GetBlockByNonce(nonce uint64, options api.BlockQueryOptions) (*api.Block, error) { - storerUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(sbp.selfShardID) + if !sbp.isBlockNonceInStorage(nonce) { + return nil, errBlockNotFound + } + + headerHash, blockBytes, err := sbp.getBlockHashAndBytesByNonce(nonce) + if err != nil { + return nil, err + } + + return sbp.convertShardBlockBytesToAPIBlock(headerHash, blockBytes, options) +} + +func (sbp *shardAPIBlockProcessor) getBlockHashAndBytesByNonce(nonce uint64) ([]byte, []byte, error) { + storerUnit := dataRetriever.GetHdrNonceHashDataUnit(sbp.selfShardID) nonceToByteSlice := sbp.uint64ByteSliceConverter.ToByteSlice(nonce) headerHash, err := sbp.store.Get(storerUnit, nonceToByteSlice) if err != nil { - return nil, err + return nil, nil, err } // if genesis block, get the nonce key corresponding to the altered block @@ -61,15 +76,15 @@ func (sbp *shardAPIBlockProcessor) GetBlockByNonce(nonce uint64, options api.Blo alteredHeaderHash, err := sbp.store.Get(storerUnit, nonceToByteSlice) if err != nil { - return nil, err + return nil, nil, err } blockBytes, err := sbp.getFromStorer(dataRetriever.BlockHeaderUnit, alteredHeaderHash) if err != nil { - return nil, err + return nil, nil, err } - return sbp.convertShardBlockBytesToAPIBlock(headerHash, blockBytes, options) + return headerHash, blockBytes, nil } // GetBlockByHash will return a shard APIBlock by hash @@ -98,7 +113,7 @@ func (sbp *shardAPIBlockProcessor) GetBlockByHash(hash []byte, options api.Block return nil, err } - storerUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(sbp.selfShardID) + storerUnit := dataRetriever.GetHdrNonceHashDataUnit(sbp.selfShardID) return sbp.computeStatusAndPutInBlock(blockAPI, 
storerUnit) } @@ -145,7 +160,7 @@ func (sbp *shardAPIBlockProcessor) getHashAndBlockBytesFromStorerByHash(params a } func (sbp *shardAPIBlockProcessor) getHashAndBlockBytesFromStorerByNonce(params api.GetBlockParameters) ([]byte, []byte, error) { - storerUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(sbp.selfShardID) + storerUnit := dataRetriever.GetHdrNonceHashDataUnit(sbp.selfShardID) nonceToByteSlice := sbp.uint64ByteSliceConverter.ToByteSlice(params.Nonce) headerHash, err := sbp.store.Get(storerUnit, nonceToByteSlice) @@ -231,6 +246,10 @@ func (sbp *shardAPIBlockProcessor) convertShardBlockBytesToAPIBlock(hash []byte, } addScheduledInfoInBlock(blockHeader, apiBlock) + err = sbp.addProof(hash, blockHeader, apiBlock) + if err != nil { + return nil, err + } return apiBlock, nil } diff --git a/node/external/blockAPI/shardBlock_test.go b/node/external/blockAPI/shardBlock_test.go index fd57c180430..480f18b0120 100644 --- a/node/external/blockAPI/shardBlock_test.go +++ b/node/external/blockAPI/shardBlock_test.go @@ -17,7 +17,9 @@ import ( "github.com/multiversx/mx-chain-go/outport/process/alteredaccounts/shared" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + dataRetrieverTestsCommon "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/state" @@ -33,6 +35,11 @@ func createMockShardAPIProcessor( withHistory bool, withKey bool, ) *shardAPIBlockProcessor { + chainHandler := &testscommon.ChainHandlerMock{} + _ = chainHandler.SetCurrentBlockHeaderAndRootHash(&block.Header{ + Nonce: 123456, + }, []byte("root")) + return newShardApiBlockProcessor(&ArgAPIBlockProcessor{ APITransactionHandler: &mock.TransactionAPIHandlerStub{}, SelfShardID: shardID, @@ -62,6 +69,9 @@ func createMockShardAPIProcessor( AlteredAccountsProvider: &testscommon.AlteredAccountsProviderStub{}, AccountsRepository: &state.AccountsRepositoryStub{}, ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, + ProofsPool: &dataRetrieverTestsCommon.ProofsPoolMock{}, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStubWithNoFlagsDefined(), + BlockChain: chainHandler, }, nil) } @@ -107,6 +117,31 @@ func TestShardAPIBlockProcessor_GetBlockByNonceInvalidNonceShouldErr(t *testing. 
assert.Error(t, err) } +func TestShardAPIBlockProcessor_BlockByNonceNonceTooHighShouldErr(t *testing.T) { + t.Parallel() + + epoch := uint32(0) + shardID := uint32(3) + headerHash := []byte("d08089f2ab739520598fd7aeed08c427460fe94f286383047f3f61951afc4e00") + + storerMock := genericMocks.NewStorerMockWithEpoch(epoch) + + blockProc := createMockShardAPIProcessor( + shardID, + headerHash, + storerMock, + true, + true, + ) + blockProc.blockchain = &testscommon.ChainHandlerMock{} + err := blockProc.blockchain.SetCurrentBlockHeaderAndRootHash(&block.Header{Nonce: 10}, []byte("root")) + require.NoError(t, err) + + res, err := blockProc.GetBlockByNonce(11, api.BlockQueryOptions{}) + require.Nil(t, res) + require.Equal(t, errBlockNotFound, err) +} + func TestShardAPIBlockProcessor_GetBlockByRoundInvalidRoundShouldErr(t *testing.T) { t.Parallel() diff --git a/node/interface.go b/node/interface.go index 23a706ed25a..05330285fb6 100644 --- a/node/interface.go +++ b/node/interface.go @@ -2,12 +2,11 @@ package node import ( "io" - "time" "github.com/multiversx/mx-chain-core-go/core" - "github.com/multiversx/mx-chain-go/p2p" - "github.com/multiversx/mx-chain-go/update" vmcommon "github.com/multiversx/mx-chain-vm-common-go" + + "github.com/multiversx/mx-chain-go/update" ) // NetworkShardingCollector defines the updating methods used by the network sharding component @@ -19,18 +18,6 @@ type NetworkShardingCollector interface { IsInterfaceNil() bool } -// P2PAntifloodHandler defines the behavior of a component able to signal that the system is too busy (or flooded) processing -// p2p messages -type P2PAntifloodHandler interface { - CanProcessMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error - CanProcessMessagesOnTopic(peer core.PeerID, topic string, numMessages uint32, totalSize uint64, sequence []byte) error - ResetForTopic(topic string) - SetMaxMessagesForTopic(topic string, maxNum uint32) - ApplyConsensusSize(size int) - BlacklistPeer(peer core.PeerID, reason string, duration time.Duration) - IsInterfaceNil() bool -} - // HardforkTrigger defines the behavior of a hardfork trigger type HardforkTrigger interface { SetExportFactoryHandler(exportFactoryHandler update.ExportFactoryHandler) error diff --git a/node/metrics/metrics.go b/node/metrics/metrics.go index ebb2a6bbe30..197ae930afd 100644 --- a/node/metrics/metrics.go +++ b/node/metrics/metrics.go @@ -204,6 +204,7 @@ func InitConfigMetrics( appStatusHandler.SetUInt64Value(common.MetricFixRelayedMoveBalanceToNonPayableSCEnableEpoch, uint64(enableEpochs.FixRelayedMoveBalanceToNonPayableSCEnableEpoch)) appStatusHandler.SetUInt64Value(common.MetricRelayedTransactionsV3EnableEpoch, uint64(enableEpochs.RelayedTransactionsV3EnableEpoch)) appStatusHandler.SetUInt64Value(common.MetricRelayedTransactionsV3FixESDTTransferEnableEpoch, uint64(enableEpochs.RelayedTransactionsV3FixESDTTransferEnableEpoch)) + appStatusHandler.SetUInt64Value(common.MetricCheckBuiltInCallOnTransferValueAndFailEnableRound, uint64(enableEpochs.CheckBuiltInCallOnTransferValueAndFailEnableRound)) appStatusHandler.SetUInt64Value(common.MetricMaskVMInternalDependenciesErrorsEnableEpoch, uint64(enableEpochs.MaskVMInternalDependenciesErrorsEnableEpoch)) appStatusHandler.SetUInt64Value(common.MetricFixBackTransferOPCODEEnableEpoch, uint64(enableEpochs.FixBackTransferOPCODEEnableEpoch)) appStatusHandler.SetUInt64Value(common.MetricValidationOnGobDecodeEnableEpoch, 
uint64(enableEpochs.ValidationOnGobDecodeEnableEpoch)) @@ -248,17 +249,17 @@ func InitRatingsMetrics(appStatusHandler core.AppStatusHandler, ratingsConfig co } appStatusHandler.SetUInt64Value(common.MetricRatingsGeneralSelectionChances+"_count", uint64(len(ratingsConfig.General.SelectionChances))) - appStatusHandler.SetUInt64Value(common.MetricRatingsShardChainHoursToMaxRatingFromStartRating, uint64(ratingsConfig.ShardChain.HoursToMaxRatingFromStartRating)) - appStatusHandler.SetStringValue(common.MetricRatingsShardChainProposerValidatorImportance, fmt.Sprintf("%f", ratingsConfig.ShardChain.ProposerValidatorImportance)) - appStatusHandler.SetStringValue(common.MetricRatingsShardChainProposerDecreaseFactor, fmt.Sprintf("%f", ratingsConfig.ShardChain.ProposerDecreaseFactor)) - appStatusHandler.SetStringValue(common.MetricRatingsShardChainValidatorDecreaseFactor, fmt.Sprintf("%f", ratingsConfig.ShardChain.ValidatorDecreaseFactor)) - appStatusHandler.SetStringValue(common.MetricRatingsShardChainConsecutiveMissedBlocksPenalty, fmt.Sprintf("%f", ratingsConfig.ShardChain.ConsecutiveMissedBlocksPenalty)) - - appStatusHandler.SetUInt64Value(common.MetricRatingsMetaChainHoursToMaxRatingFromStartRating, uint64(ratingsConfig.MetaChain.HoursToMaxRatingFromStartRating)) - appStatusHandler.SetStringValue(common.MetricRatingsMetaChainProposerValidatorImportance, fmt.Sprintf("%f", ratingsConfig.MetaChain.ProposerValidatorImportance)) - appStatusHandler.SetStringValue(common.MetricRatingsMetaChainProposerDecreaseFactor, fmt.Sprintf("%f", ratingsConfig.MetaChain.ProposerDecreaseFactor)) - appStatusHandler.SetStringValue(common.MetricRatingsMetaChainValidatorDecreaseFactor, fmt.Sprintf("%f", ratingsConfig.MetaChain.ValidatorDecreaseFactor)) - appStatusHandler.SetStringValue(common.MetricRatingsMetaChainConsecutiveMissedBlocksPenalty, fmt.Sprintf("%f", ratingsConfig.MetaChain.ConsecutiveMissedBlocksPenalty)) + appStatusHandler.SetUInt64Value(common.MetricRatingsShardChainHoursToMaxRatingFromStartRating, uint64(ratingsConfig.ShardChain.RatingStepsByEpoch[0].HoursToMaxRatingFromStartRating)) + appStatusHandler.SetStringValue(common.MetricRatingsShardChainProposerValidatorImportance, fmt.Sprintf("%f", ratingsConfig.ShardChain.RatingStepsByEpoch[0].ProposerValidatorImportance)) + appStatusHandler.SetStringValue(common.MetricRatingsShardChainProposerDecreaseFactor, fmt.Sprintf("%f", ratingsConfig.ShardChain.RatingStepsByEpoch[0].ProposerDecreaseFactor)) + appStatusHandler.SetStringValue(common.MetricRatingsShardChainValidatorDecreaseFactor, fmt.Sprintf("%f", ratingsConfig.ShardChain.RatingStepsByEpoch[0].ValidatorDecreaseFactor)) + appStatusHandler.SetStringValue(common.MetricRatingsShardChainConsecutiveMissedBlocksPenalty, fmt.Sprintf("%f", ratingsConfig.ShardChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty)) + + appStatusHandler.SetUInt64Value(common.MetricRatingsMetaChainHoursToMaxRatingFromStartRating, uint64(ratingsConfig.MetaChain.RatingStepsByEpoch[0].HoursToMaxRatingFromStartRating)) + appStatusHandler.SetStringValue(common.MetricRatingsMetaChainProposerValidatorImportance, fmt.Sprintf("%f", ratingsConfig.MetaChain.RatingStepsByEpoch[0].ProposerValidatorImportance)) + appStatusHandler.SetStringValue(common.MetricRatingsMetaChainProposerDecreaseFactor, fmt.Sprintf("%f", ratingsConfig.MetaChain.RatingStepsByEpoch[0].ProposerDecreaseFactor)) + appStatusHandler.SetStringValue(common.MetricRatingsMetaChainValidatorDecreaseFactor, fmt.Sprintf("%f", 
ratingsConfig.MetaChain.RatingStepsByEpoch[0].ValidatorDecreaseFactor)) + appStatusHandler.SetStringValue(common.MetricRatingsMetaChainConsecutiveMissedBlocksPenalty, fmt.Sprintf("%f", ratingsConfig.MetaChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty)) appStatusHandler.SetStringValue(common.MetricRatingsPeerHonestyDecayCoefficient, fmt.Sprintf("%f", ratingsConfig.PeerHonesty.DecayCoefficient)) appStatusHandler.SetUInt64Value(common.MetricRatingsPeerHonestyDecayUpdateIntervalInSeconds, uint64(ratingsConfig.PeerHonesty.DecayUpdateIntervalInSeconds)) diff --git a/node/metrics/metrics_test.go b/node/metrics/metrics_test.go index 2326bcfb535..8af33e8a334 100644 --- a/node/metrics/metrics_test.go +++ b/node/metrics/metrics_test.go @@ -213,11 +213,12 @@ func TestInitConfigMetrics(t *testing.T) { FixRelayedMoveBalanceToNonPayableSCEnableEpoch: 106, RelayedTransactionsV3EnableEpoch: 107, RelayedTransactionsV3FixESDTTransferEnableEpoch: 108, - MaskVMInternalDependenciesErrorsEnableEpoch: 109, - FixBackTransferOPCODEEnableEpoch: 110, - ValidationOnGobDecodeEnableEpoch: 111, - BarnardOpcodesEnableEpoch: 112, - AutomaticActivationOfNodesDisableEpoch: 108, + CheckBuiltInCallOnTransferValueAndFailEnableRound: 109, + MaskVMInternalDependenciesErrorsEnableEpoch: 110, + FixBackTransferOPCODEEnableEpoch: 111, + ValidationOnGobDecodeEnableEpoch: 112, + BarnardOpcodesEnableEpoch: 113, + AutomaticActivationOfNodesDisableEpoch: 114, MaxNodesChangeEnableEpoch: []config.MaxNodesChangeConfig{ { EpochEnable: 0, @@ -341,11 +342,12 @@ func TestInitConfigMetrics(t *testing.T) { "erd_fix_relayed_move_balance_to_non_payable_sc_enable_epoch": uint32(106), "erd_relayed_transactions_v3_enable_epoch": uint32(107), "erd_relayed_transactions_v3_fix_esdt_transfer_enable_epoch": uint32(108), - "erd_mask_vm_internal_dependencies_errors_enable_epoch": uint32(109), - "erd_fix_back_transfer_opcode_enable_epoch": uint32(110), - "erd_validation_on_gobdecode_enable_epoch": uint32(111), - "erd_barnard_opcodes_enable_epoch": uint32(112), - "erd_automatic_activation_of_nodes_disable_epoch": uint32(108), + "erd_checkbuiltincall_ontransfervalueandfail_enable_round": uint32(109), + "erd_mask_vm_internal_dependencies_errors_enable_epoch": uint32(110), + "erd_fix_back_transfer_opcode_enable_epoch": uint32(111), + "erd_validation_on_gobdecode_enable_epoch": uint32(112), + "erd_barnard_opcodes_enable_epoch": uint32(113), + "erd_automatic_activation_of_nodes_disable_epoch": uint32(114), "erd_max_nodes_change_enable_epoch": nil, "erd_total_supply": "12345", "erd_hysteresis": "0.100000", @@ -428,21 +430,25 @@ func TestInitRatingsMetrics(t *testing.T) { }, }, ShardChain: config.ShardChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 10, - ProposerValidatorImportance: 0.1, - ProposerDecreaseFactor: 0.1, - ValidatorDecreaseFactor: 0.1, - ConsecutiveMissedBlocksPenalty: 0.1, + RatingStepsByEpoch: []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 10, + ProposerValidatorImportance: 0.1, + ProposerDecreaseFactor: 0.1, + ValidatorDecreaseFactor: 0.1, + ConsecutiveMissedBlocksPenalty: 0.1, + }, }, }, MetaChain: config.MetaChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 10, - ProposerValidatorImportance: 0.1, - ProposerDecreaseFactor: 0.1, - ValidatorDecreaseFactor: 0.1, - ConsecutiveMissedBlocksPenalty: 0.1, + RatingStepsByEpoch: []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 10, + ProposerValidatorImportance: 0.1, + ProposerDecreaseFactor: 0.1, + 
ValidatorDecreaseFactor: 0.1, + ConsecutiveMissedBlocksPenalty: 0.1, + }, }, }, PeerHonesty: config.PeerHonestyConfig{ diff --git a/node/mock/epochStartNotifier.go b/node/mock/epochStartNotifier.go index c4675a37401..8c6fb4c51e8 100644 --- a/node/mock/epochStartNotifier.go +++ b/node/mock/epochStartNotifier.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/node/mock/factory/coreComponentsStub.go b/node/mock/factory/coreComponentsStub.go index 24c10b94a52..80bcc8ac16d 100644 --- a/node/mock/factory/coreComponentsStub.go +++ b/node/mock/factory/coreComponentsStub.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/factory" @@ -20,39 +21,43 @@ import ( // CoreComponentsMock - type CoreComponentsMock struct { - IntMarsh marshal.Marshalizer - TxMarsh marshal.Marshalizer - VmMarsh marshal.Marshalizer - Hash hashing.Hasher - TxSignHasherField hashing.Hasher - UInt64ByteSliceConv typeConverters.Uint64ByteSliceConverter - AddrPubKeyConv core.PubkeyConverter - ValPubKeyConv core.PubkeyConverter - PathHdl storage.PathManagerHandler - ChainIdCalled func() string - MinTransactionVersionCalled func() uint32 - WDTimer core.WatchdogTimer - Alarm core.TimersScheduler - NtpTimer ntp.SyncTimer - RoundChangeNotifier process.RoundNotifier - RoundHandlerField consensus.RoundHandler - EconomicsHandler process.EconomicsDataHandler - APIEconomicsHandler process.EconomicsDataHandler - RatingsConfig process.RatingsInfoHandler - RatingHandler sharding.PeerAccountListAndRatingHandler - NodesConfig sharding.GenesisNodesSetupHandler - EpochChangeNotifier process.EpochNotifier - EnableRoundsHandlerField process.EnableRoundsHandler - EpochNotifierWithConfirm factory.EpochStartNotifierWithConfirm - ChanStopProcess chan endProcess.ArgEndProcess - Shuffler nodesCoordinator.NodesShuffler - TxVersionCheckHandler process.TxVersionCheckerHandler - StartTime time.Time - NodeTypeProviderField core.NodeTypeProviderHandler - WasmVMChangeLockerInternal common.Locker - ProcessStatusHandlerInternal common.ProcessStatusHandler - HardforkTriggerPubKeyField []byte - EnableEpochsHandlerField common.EnableEpochsHandler + IntMarsh marshal.Marshalizer + TxMarsh marshal.Marshalizer + VmMarsh marshal.Marshalizer + Hash hashing.Hasher + TxSignHasherField hashing.Hasher + UInt64ByteSliceConv typeConverters.Uint64ByteSliceConverter + AddrPubKeyConv core.PubkeyConverter + ValPubKeyConv core.PubkeyConverter + PathHdl storage.PathManagerHandler + ChainIdCalled func() string + MinTransactionVersionCalled func() uint32 + WDTimer core.WatchdogTimer + Alarm core.TimersScheduler + NtpTimer ntp.SyncTimer + RoundChangeNotifier process.RoundNotifier + RoundHandlerField consensus.RoundHandler + EconomicsHandler process.EconomicsDataHandler + APIEconomicsHandler process.EconomicsDataHandler + RatingsConfig process.RatingsInfoHandler + RatingHandler sharding.PeerAccountListAndRatingHandler + NodesConfig sharding.GenesisNodesSetupHandler + EpochChangeNotifier process.EpochNotifier + EnableRoundsHandlerField process.EnableRoundsHandler + EpochNotifierWithConfirm 
factory.EpochStartNotifierWithConfirm + ChanStopProcess chan endProcess.ArgEndProcess + Shuffler nodesCoordinator.NodesShuffler + TxVersionCheckHandler process.TxVersionCheckerHandler + StartTime time.Time + NodeTypeProviderField core.NodeTypeProviderHandler + WasmVMChangeLockerInternal common.Locker + ProcessStatusHandlerInternal common.ProcessStatusHandler + HardforkTriggerPubKeyField []byte + EnableEpochsHandlerField common.EnableEpochsHandler + ChainParametersHandlerField process.ChainParametersHandler + ChainParametersSubscriberField process.ChainParametersSubscriber + FieldsSizeCheckerField common.FieldsSizeChecker + EpochChangeGracePeriodHandlerField common.EpochChangeGracePeriodHandler } // Create - @@ -258,6 +263,26 @@ func (ccm *CoreComponentsMock) EnableEpochsHandler() common.EnableEpochsHandler return ccm.EnableEpochsHandlerField } +// ChainParametersHandler - +func (ccm *CoreComponentsMock) ChainParametersHandler() process.ChainParametersHandler { + return ccm.ChainParametersHandlerField +} + +// ChainParametersSubscriber - +func (ccm *CoreComponentsMock) ChainParametersSubscriber() process.ChainParametersSubscriber { + return ccm.ChainParametersSubscriberField +} + +// FieldsSizeChecker - +func (ccm *CoreComponentsMock) FieldsSizeChecker() common.FieldsSizeChecker { + return ccm.FieldsSizeCheckerField +} + +// EpochChangeGracePeriodHandler - +func (ccm *CoreComponentsMock) EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler { + return ccm.EpochChangeGracePeriodHandlerField +} + // IsInterfaceNil - func (ccm *CoreComponentsMock) IsInterfaceNil() bool { return ccm == nil diff --git a/node/mock/forkDetectorMock.go b/node/mock/forkDetectorMock.go index d681b976d7d..7458887b48b 100644 --- a/node/mock/forkDetectorMock.go +++ b/node/mock/forkDetectorMock.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -19,6 +20,8 @@ type ForkDetectorMock struct { RestoreToGenesisCalled func() ResetProbableHighestNonceCalled func() SetFinalToLastCheckpointCalled func() + ReceivedProofCalled func(proof data.HeaderProofHandler) + AddCheckpointCalled func(nonce uint64, round uint64, hash []byte) } // RestoreToGenesis - @@ -80,6 +83,13 @@ func (fdm *ForkDetectorMock) ResetProbableHighestNonce() { } } +// AddCheckpoint - +func (fdm *ForkDetectorMock) AddCheckpoint(nonce uint64, round uint64, hash []byte) { + if fdm.AddCheckpointCalled != nil { + fdm.AddCheckpointCalled(nonce, round, hash) + } +} + // SetFinalToLastCheckpoint - func (fdm *ForkDetectorMock) SetFinalToLastCheckpoint() { if fdm.SetFinalToLastCheckpointCalled != nil { @@ -87,6 +97,13 @@ func (fdm *ForkDetectorMock) SetFinalToLastCheckpoint() { } } +// ReceivedProof - +func (fdm *ForkDetectorMock) ReceivedProof(proof data.HeaderProofHandler) { + if fdm.ReceivedProofCalled != nil { + fdm.ReceivedProofCalled(proof) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (fdm *ForkDetectorMock) IsInterfaceNil() bool { return fdm == nil diff --git a/node/mock/headerSigVerifierStub.go b/node/mock/headerSigVerifierStub.go deleted file mode 100644 index b75b5615a12..00000000000 --- a/node/mock/headerSigVerifierStub.go +++ /dev/null @@ -1,52 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderSigVerifierStub - -type HeaderSigVerifierStub struct { - VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error - 
VerifySignatureCalled func(header data.HeaderHandler) error - VerifyRandSeedCalled func(header data.HeaderHandler) error - VerifyLeaderSignatureCalled func(header data.HeaderHandler) error -} - -// VerifyRandSeed - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeed(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedCalled != nil { - return hsvm.VerifyRandSeedCalled(header) - } - - return nil -} - -// VerifyRandSeedAndLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedAndLeaderSignatureCalled != nil { - return hsvm.VerifyRandSeedAndLeaderSignatureCalled(header) - } - - return nil -} - -// VerifySignature - -func (hsvm *HeaderSigVerifierStub) VerifySignature(header data.HeaderHandler) error { - if hsvm.VerifySignatureCalled != nil { - return hsvm.VerifySignatureCalled(header) - } - - return nil -} - -// VerifyLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyLeaderSignatureCalled != nil { - return hsvm.VerifyLeaderSignatureCalled(header) - } - - return nil -} - -// IsInterfaceNil - -func (hsvm *HeaderSigVerifierStub) IsInterfaceNil() bool { - return hsvm == nil -} diff --git a/node/mock/p2pAntifloodHandlerStub.go b/node/mock/p2pAntifloodHandlerStub.go index bda3da406d5..92f7bafdc88 100644 --- a/node/mock/p2pAntifloodHandlerStub.go +++ b/node/mock/p2pAntifloodHandlerStub.go @@ -16,6 +16,7 @@ type P2PAntifloodHandlerStub struct { SetDebuggerCalled func(debugger process.AntifloodDebugger) error BlacklistPeerCalled func(peer core.PeerID, reason string, duration time.Duration) IsOriginatorEligibleForTopicCalled func(pid core.PeerID, topic string) error + SetConsensusSizeNotifierCalled func(chainParametersNotifier process.ChainParametersSubscriber, shardID uint32) } // CanProcessMessage - @@ -75,6 +76,13 @@ func (p2pahs *P2PAntifloodHandlerStub) SetMaxMessagesForTopic(_ string, _ uint32 } +// SetConsensusSizeNotifier - +func (p2pahs *P2PAntifloodHandlerStub) SetConsensusSizeNotifier(chainParametersNotifier process.ChainParametersSubscriber, shardID uint32) { + if p2pahs.SetConsensusSizeNotifierCalled != nil { + p2pahs.SetConsensusSizeNotifierCalled(chainParametersNotifier, shardID) + } +} + // SetPeerValidatorMapper - func (p2pahs *P2PAntifloodHandlerStub) SetPeerValidatorMapper(_ process.PeerValidatorMapper) error { return nil diff --git a/node/mock/rounderMock.go b/node/mock/rounderMock.go index a0723777fc1..f6d933fcbe1 100644 --- a/node/mock/rounderMock.go +++ b/node/mock/rounderMock.go @@ -24,6 +24,9 @@ func (rndm *RoundHandlerMock) BeforeGenesis() bool { return false } +// RevertOneRound - +func (rndm *RoundHandlerMock) RevertOneRound() {} + // Index - func (rndm *RoundHandlerMock) Index() int64 { if rndm.IndexCalled != nil { diff --git a/node/node.go b/node/node.go index 38b00841d2a..90c29f3c063 100644 --- a/node/node.go +++ b/node/node.go @@ -71,7 +71,6 @@ type accountInfo struct { type Node struct { initialNodesPubkeys map[uint32][]string roundDuration uint64 - consensusGroupSize int genesisTime time.Time peerDenialEvaluator p2p.PeerDenialEvaluator esdtStorageHandler vmcommon.ESDTNFTStorageHandler @@ -157,11 +156,6 @@ func (n *Node) CreateShardedStores() error { return nil } -// GetConsensusGroupSize returns the configured consensus size -func (n *Node) GetConsensusGroupSize() int { - return n.consensusGroupSize -} - // GetBalance gets the balance for a specific address func (n *Node) GetBalance(address string, 
options api.AccountQueryOptions) (*big.Int, api.BlockInfo, error) { userAccount, blockInfo, err := n.loadUserAccountHandlerByAddress(address, options) @@ -321,6 +315,43 @@ func (n *Node) GetKeyValuePairs(address string, options api.AccountQueryOptions, return mapToReturn, blockInfo, nil } +type userAccountWithLeavesParser interface { + GetLeavesParser() common.TrieLeafParser +} + +// IterateKeys starts from the given iteratorState and returns the next key-value pairs and the new iteratorState +func (n *Node) IterateKeys(address string, numKeys uint, iteratorState [][]byte, options api.AccountQueryOptions, ctx context.Context) (map[string]string, [][]byte, api.BlockInfo, error) { + userAccount, blockInfo, err := n.loadUserAccountHandlerByAddress(address, options) + if err != nil { + adaptedBlockInfo, isEmptyAccount := extractBlockInfoIfNewAccount(err) + if isEmptyAccount { + return make(map[string]string), nil, adaptedBlockInfo, nil + } + + return nil, nil, api.BlockInfo{}, err + } + + if check.IfNil(userAccount.DataTrie()) { + return map[string]string{}, nil, blockInfo, nil + } + + account, ok := userAccount.(userAccountWithLeavesParser) + if !ok { + return nil, nil, api.BlockInfo{}, fmt.Errorf("cannot cast user account to userAccountWithLeavesParser") + } + + if len(iteratorState) == 0 { + iteratorState = append(iteratorState, userAccount.GetRootHash()) + } + + mapToReturn, newIteratorState, err := n.stateComponents.TrieLeavesRetriever().GetLeaves(int(numKeys), iteratorState, account.GetLeavesParser(), ctx) + if err != nil { + return nil, nil, api.BlockInfo{}, err + } + + return mapToReturn, newIteratorState, blockInfo, nil +} + func (n *Node) getKeys(userAccount state.UserAccountHandler, ctx context.Context) (map[string]string, error) { chLeaves := &common.TrieIteratorChannels{ LeavesChan: make(chan core.KeyValueHolder, common.TrieLeavesChannelDefaultCapacity), diff --git a/node/nodeHelper.go b/node/nodeHelper.go index b1b5a27c816..2e0099cc3d6 100644 --- a/node/nodeHelper.go +++ b/node/nodeHelper.go @@ -65,11 +65,6 @@ func CreateNode( genesisTime := time.Unix(coreComponents.GenesisNodesSetup().GetStartTime(), 0) - consensusGroupSize, err := consensusComponents.ConsensusGroupSize() - if err != nil { - return nil, err - } - var nd *Node nd, err = NewNode( WithStatusCoreComponents(statusCoreComponents), @@ -85,7 +80,6 @@ func CreateNode( WithNetworkComponents(networkComponents), WithInitialNodesPubKeys(coreComponents.GenesisNodesSetup().InitialNodesPubKeys()), WithRoundDuration(coreComponents.GenesisNodesSetup().GetRoundDuration()), - WithConsensusGroupSize(consensusGroupSize), WithGenesisTime(genesisTime), WithConsensusType(config.Consensus.Type), WithBootstrapRoundIndex(bootstrapRoundIndex), diff --git a/node/nodeRunner.go b/node/nodeRunner.go index f86cb68f140..c6e362c6bad 100644 --- a/node/nodeRunner.go +++ b/node/nodeRunner.go @@ -20,6 +20,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/data/endProcess" outportCore "github.com/multiversx/mx-chain-core-go/data/outport" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/api/gin" "github.com/multiversx/mx-chain-go/api/shared" "github.com/multiversx/mx-chain-go/common" @@ -61,7 +63,6 @@ import ( "github.com/multiversx/mx-chain-go/storage/storageunit" trieStatistics 
"github.com/multiversx/mx-chain-go/trie/statistics" "github.com/multiversx/mx-chain-go/update/trigger" - logger "github.com/multiversx/mx-chain-logger-go" ) type nextOperationForNode int @@ -400,6 +401,7 @@ func (nr *nodeRunner) executeOneComponentCreationCycle( managedCoreComponents.EnableEpochsHandler(), managedDataComponents.Datapool().CurrentEpochValidatorInfo(), managedBootstrapComponents.NodesCoordinatorRegistryFactory(), + managedCoreComponents.ChainParametersHandler(), ) if err != nil { return true, err @@ -1476,7 +1478,7 @@ func (nr *nodeRunner) CreateManagedCoreComponents( ImportDbConfig: *nr.configs.ImportDbConfig, RatingsConfig: *nr.configs.RatingsConfig, EconomicsConfig: *nr.configs.EconomicsConfig, - NodesFilename: nr.configs.ConfigurationPathsHolder.Nodes, + NodesConfig: *nr.configs.NodesConfig, WorkingDirectory: nr.configs.FlagsConfig.DbDir, ChanStopNodeProcess: chanStopNodeProcess, } diff --git a/node/node_test.go b/node/node_test.go index bc77938989d..ee90fe2eaa6 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -50,6 +50,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/bootstrapMocks" + "github.com/multiversx/mx-chain-go/testscommon/consensus" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ -4678,7 +4679,6 @@ func TestNode_Getters(t *testing.T) { heartbeatComponents := &factoryMock.HeartbeatV2ComponentsStub{} networkComponents := getDefaultNetworkComponents() processComponents := getDefaultProcessComponents() - consensusGroupSize := 10 n, err := node.NewNode( node.WithCoreComponents(coreComponents), @@ -4690,7 +4690,6 @@ func TestNode_Getters(t *testing.T) { node.WithHeartbeatV2Components(heartbeatComponents), node.WithNetworkComponents(networkComponents), node.WithProcessComponents(processComponents), - node.WithConsensusGroupSize(consensusGroupSize), node.WithImportMode(true), ) require.Nil(t, err) @@ -4705,7 +4704,6 @@ func TestNode_Getters(t *testing.T) { assert.True(t, n.GetHeartbeatV2Components() == heartbeatComponents) assert.True(t, n.GetNetworkComponents() == networkComponents) assert.True(t, n.GetProcessComponents() == processComponents) - assert.Equal(t, consensusGroupSize, n.GetConsensusGroupSize()) assert.True(t, n.IsInImportMode()) } @@ -5343,6 +5341,7 @@ func getDefaultCoreComponents() *nodeMockFactory.CoreComponentsMock { EpochChangeNotifier: &epochNotifier.EpochNotifierStub{}, TxVersionCheckHandler: versioning.NewTxVersionChecker(0), EnableEpochsHandlerField: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(common.RelayedTransactionsV3Flag), + FieldsSizeCheckerField: &testscommon.FieldsSizeCheckerMock{}, } } @@ -5362,7 +5361,7 @@ func getDefaultProcessComponents() *factoryMock.ProcessComponentsMock { BlockProcess: &testscommon.BlockProcessorStub{}, BlackListHdl: &testscommon.TimeCacheStub{}, BootSore: &mock.BootstrapStorerMock{}, - HeaderSigVerif: &mock.HeaderSigVerifierStub{}, + HeaderSigVerif: &consensus.HeaderSigVerifierMock{}, HeaderIntegrVerif: &mock.HeaderIntegrityVerifierStub{}, ValidatorStatistics: &testscommon.ValidatorStatisticsProcessorStub{}, ValidatorProvider: &stakingcommon.ValidatorsProviderStub{}, diff 
--git a/node/options.go b/node/options.go index 90385b3b8f4..f3ddc9bde10 100644 --- a/node/options.go +++ b/node/options.go @@ -210,18 +210,6 @@ func WithRoundDuration(roundDuration uint64) Option { } } -// WithConsensusGroupSize sets up the consensus group size option for the Node -func WithConsensusGroupSize(consensusGroupSize int) Option { - return func(n *Node) error { - if consensusGroupSize < 1 { - return ErrNegativeOrZeroConsensusGroupSize - } - log.Info("consensus group", "size", consensusGroupSize) - n.consensusGroupSize = consensusGroupSize - return nil - } -} - // WithGenesisTime sets up the genesis time option for the Node func WithGenesisTime(genesisTime time.Time) Option { return func(n *Node) error { diff --git a/node/options_test.go b/node/options_test.go index fa4a92ea449..1f565b58d34 100644 --- a/node/options_test.go +++ b/node/options_test.go @@ -71,32 +71,6 @@ func TestWithRoundDuration_ShouldWork(t *testing.T) { assert.Nil(t, err) } -func TestWithConsensusGroupSize_NegativeGroupSizeShouldErr(t *testing.T) { - t.Parallel() - - node, _ := NewNode() - - opt := WithConsensusGroupSize(-1) - err := opt(node) - - assert.Equal(t, 0, node.consensusGroupSize) - assert.Equal(t, ErrNegativeOrZeroConsensusGroupSize, err) -} - -func TestWithConsensusGroupSize_ShouldWork(t *testing.T) { - t.Parallel() - - node, _ := NewNode() - - groupSize := 567 - - opt := WithConsensusGroupSize(groupSize) - err := opt(node) - - assert.True(t, node.consensusGroupSize == groupSize) - assert.Nil(t, err) -} - func TestWithGenesisTime(t *testing.T) { t.Parallel() diff --git a/outport/errors.go b/outport/errors.go index 8c7ce22bb98..97a9a3047d7 100644 --- a/outport/errors.go +++ b/outport/errors.go @@ -11,9 +11,6 @@ var ErrNilArgsOutportFactory = errors.New("nil args outport driver factory") // ErrInvalidRetrialInterval signals that an invalid retrial interval was provided var ErrInvalidRetrialInterval = errors.New("invalid retrial interval") -// ErrNilPubKeyConverter signals that a nil pubkey converter has been provided -var ErrNilPubKeyConverter = errors.New("nil pub key converter") - var errNilSaveBlockArgs = errors.New("nil save blocks args provided") var errNilHeaderAndBodyArgs = errors.New("nil header and body args provided") diff --git a/outport/outport.go b/outport/outport.go index edcecc0691a..542af9ec493 100644 --- a/outport/outport.go +++ b/outport/outport.go @@ -6,6 +6,8 @@ import ( "sync/atomic" "time" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-core-go/core/check" outportcore "github.com/multiversx/mx-chain-core-go/data/outport" logger "github.com/multiversx/mx-chain-logger-go" @@ -84,6 +86,14 @@ func prepareBlockData( return nil, err } + var proof *block.HeaderProof + if !check.IfNil(headerBodyData.HeaderProof) { + proof, err = outportcore.GetHeaderProof(headerBodyData.HeaderProof) + if err != nil { + return nil, err + } + } + return &outportcore.BlockData{ ShardID: headerBodyData.Header.GetShardID(), HeaderBytes: headerBytes, @@ -91,6 +101,7 @@ func prepareBlockData( HeaderHash: headerBodyData.HeaderHash, Body: body, IntraShardMiniBlocks: headerBodyData.IntraShardMiniBlocks, + HeaderProof: proof, }, nil } diff --git a/outport/process/errors.go b/outport/process/errors.go index 70f13566dfb..853705659aa 100644 --- a/outport/process/errors.go +++ b/outport/process/errors.go @@ -60,8 +60,8 @@ var ErrNilMiniBlock = errors.New("nil miniBlock") // ErrNilExecutedTxHashes signals that a 
nil executed tx hashes map has been provided var ErrNilExecutedTxHashes = errors.New("nil executed tx hashes map") -// ErrNilOrderedTxHashes signals that a nil ordered tx list has been provided -var ErrNilOrderedTxHashes = errors.New("nil ordered tx list") - // ErrIndexOutOfBounds signals that an index is out of bounds var ErrIndexOutOfBounds = errors.New("index out of bounds") + +// ErrNilProofsPool signals that a nil proofs pool was used +var ErrNilProofsPool = errors.New("nil proofs pool") diff --git a/outport/process/factory/check.go b/outport/process/factory/check.go index ce224969cc2..426bfe19f50 100644 --- a/outport/process/factory/check.go +++ b/outport/process/factory/check.go @@ -50,6 +50,9 @@ func checkArgOutportDataProviderFactory(arg ArgOutportDataProviderFactory) error if check.IfNil(arg.ExecutionOrderGetter) { return process.ErrNilExecutionOrderGetter } + if check.IfNil(arg.ProofsPool) { + return process.ErrNilProofsPool + } return nil } diff --git a/outport/process/factory/check_test.go b/outport/process/factory/check_test.go index 513a3c7305b..18406020573 100644 --- a/outport/process/factory/check_test.go +++ b/outport/process/factory/check_test.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-go/outport/process/transactionsfee" "github.com/multiversx/mx-chain-go/testscommon" commonMocks "github.com/multiversx/mx-chain-go/testscommon/common" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" @@ -34,6 +35,7 @@ func createArgOutportDataProviderFactory() ArgOutportDataProviderFactory { MbsStorer: &genericMocks.StorerMock{}, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, ExecutionOrderGetter: &commonMocks.TxExecutionOrderHandlerStub{}, + ProofsPool: &dataRetriever.ProofsPoolMock{}, } } @@ -84,6 +86,10 @@ func TestCheckArgCreateOutportDataProvider(t *testing.T) { arg.Hasher = nil require.Equal(t, process.ErrNilHasher, checkArgOutportDataProviderFactory(arg)) + arg = createArgOutportDataProviderFactory() + arg.ProofsPool = nil + require.Equal(t, process.ErrNilProofsPool, checkArgOutportDataProviderFactory(arg)) + arg = createArgOutportDataProviderFactory() require.Nil(t, checkArgOutportDataProviderFactory(arg)) } diff --git a/outport/process/factory/outportDataProviderFactory.go b/outport/process/factory/outportDataProviderFactory.go index 5bb2c698136..28c92c7f732 100644 --- a/outport/process/factory/outportDataProviderFactory.go +++ b/outport/process/factory/outportDataProviderFactory.go @@ -5,6 +5,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/outport" "github.com/multiversx/mx-chain-go/outport/process" "github.com/multiversx/mx-chain-go/outport/process/alteredaccounts" @@ -37,6 +38,7 @@ type ArgOutportDataProviderFactory struct { MbsStorer storage.Storer EnableEpochsHandler common.EnableEpochsHandler ExecutionOrderGetter common.ExecutionOrderGetter + ProofsPool dataRetriever.ProofsPool } // 
CreateOutportDataProvider will create a new instance of outport.DataProviderOutport @@ -85,5 +87,7 @@ func CreateOutportDataProvider(arg ArgOutportDataProviderFactory) (outport.DataP ExecutionOrderHandler: arg.ExecutionOrderGetter, Hasher: arg.Hasher, Marshaller: arg.Marshaller, + ProofsPool: arg.ProofsPool, + EnableEpochsHandler: arg.EnableEpochsHandler, }) } diff --git a/outport/process/outportDataProvider.go b/outport/process/outportDataProvider.go index aec1f15df8b..3f3e63ef790 100644 --- a/outport/process/outportDataProvider.go +++ b/outport/process/outportDataProvider.go @@ -3,6 +3,7 @@ package process import ( "encoding/hex" "fmt" + "math/big" "github.com/multiversx/mx-chain-core-go/core" @@ -16,12 +17,14 @@ import ( "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/outport/process/alteredaccounts/shared" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("outport/process/outportDataProvider") @@ -39,6 +42,8 @@ type ArgOutportDataProvider struct { Marshaller marshal.Marshalizer Hasher hashing.Hasher ExecutionOrderHandler common.ExecutionOrderGetter + ProofsPool dataRetriever.ProofsPool + EnableEpochsHandler common.EnableEpochsHandler } // ArgPrepareOutportSaveBlockData holds the arguments needed for prepare outport save block data @@ -67,6 +72,8 @@ type outportDataProvider struct { executionOrderHandler common.ExecutionOrderGetter marshaller marshal.Marshalizer hasher hashing.Hasher + proofsPool dataRetriever.ProofsPool + enableEpochsHandler common.EnableEpochsHandler } // NewOutportDataProvider will create a new instance of outportDataProvider @@ -83,6 +90,8 @@ func NewOutportDataProvider(arg ArgOutportDataProvider) (*outportDataProvider, e executionOrderHandler: arg.ExecutionOrderHandler, marshaller: arg.Marshaller, hasher: arg.Hasher, + proofsPool: arg.ProofsPool, + enableEpochsHandler: arg.EnableEpochsHandler, }, nil } @@ -124,7 +133,7 @@ func (odp *outportDataProvider) PrepareOutportSaveBlockData(arg ArgPrepareOutpor return nil, fmt.Errorf("alteredAccountsProvider.ExtractAlteredAccountsFromPool %s", err) } - signersIndexes, err := odp.getSignersIndexes(arg.Header) + leaderBlsKey, leaderIndex, signersIndexes, err := odp.getSignersIndexes(arg.Header) if err != nil { return nil, err } @@ -134,7 +143,7 @@ func (odp *outportDataProvider) PrepareOutportSaveBlockData(arg ArgPrepareOutpor return nil, err } - return &outportcore.OutportBlockWithHeaderAndBody{ + outportBlock := &outportcore.OutportBlockWithHeaderAndBody{ OutportBlock: &outportcore.OutportBlock{ ShardID: odp.shardID, BlockData: nil, // this will be filled with specific data for each driver @@ -152,6 +161,8 @@ func (odp *outportDataProvider) PrepareOutportSaveBlockData(arg ArgPrepareOutpor HighestFinalBlockNonce: arg.HighestFinalBlockNonce, HighestFinalBlockHash: arg.HighestFinalBlockHash, + LeaderIndex: leaderIndex, + LeaderBLSKey: []byte(leaderBlsKey), }, 
HeaderDataWithBody: &outportcore.HeaderDataWithBody{ Body: arg.Body, @@ -159,7 +170,18 @@ func (odp *outportDataProvider) PrepareOutportSaveBlockData(arg ArgPrepareOutpor HeaderHash: arg.HeaderHash, IntraShardMiniBlocks: intraMiniBlocks, }, - }, nil + } + + if odp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, arg.Header.GetEpoch()) { + headerProof, err := odp.proofsPool.GetProof(arg.Header.GetShardID(), arg.HeaderHash) + if err != nil { + return nil, err + } + + outportBlock.HeaderDataWithBody.HeaderProof = headerProof + } + + return outportBlock, nil } func collectExecutedTxHashes(bodyHandler data.BodyHandler, headerHandler data.HeaderHandler) (map[string]struct{}, error) { @@ -290,24 +312,41 @@ func (odp *outportDataProvider) computeEpoch(header data.HeaderHandler) uint32 { return epoch } -func (odp *outportDataProvider) getSignersIndexes(header data.HeaderHandler) ([]uint64, error) { +func (odp *outportDataProvider) getSignersIndexes(header data.HeaderHandler) (string, uint64, []uint64, error) { epoch := odp.computeEpoch(header) - pubKeys, err := odp.nodesCoordinator.GetConsensusValidatorsPublicKeys( + leader, pubKeys, err := odp.nodesCoordinator.GetConsensusValidatorsPublicKeys( header.GetPrevRandSeed(), header.GetRound(), odp.shardID, epoch, ) + if err != nil { - return nil, fmt.Errorf("nodesCoordinator.GetConsensusValidatorsPublicKeys %w", err) + return "", 0, nil, fmt.Errorf("nodesCoordinator.GetConsensusValidatorsPublicKeys %w", err) + } + + leaderIndex := findLeaderIndex(pubKeys, leader) + + signersIndexes := make([]uint64, 0) + // when Andromeda flag is enabled signer indices can be empty because all validators are in consensus group + if odp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + return leader, leaderIndex, signersIndexes, nil } - signersIndexes, err := odp.nodesCoordinator.GetValidatorsIndexes(pubKeys, epoch) + signersIndexes, err = odp.nodesCoordinator.GetValidatorsIndexes(pubKeys, epoch) if err != nil { - return nil, fmt.Errorf("nodesCoordinator.GetValidatorsIndexes %s", err) + return "", 0, nil, fmt.Errorf("nodesCoordinator.GetValidatorsIndexes %s", err) } + return leader, leaderIndex, signersIndexes, nil +} - return signersIndexes, nil +func findLeaderIndex(blsKeys []string, leaderBlsKey string) uint64 { + for i := 0; i < len(blsKeys); i++ { + if blsKeys[i] == leaderBlsKey { + return uint64(i) + } + } + return 0 } func (odp *outportDataProvider) createPool(rewardsTxs map[string]data.TransactionHandler) (*outportcore.TransactionPool, error) { diff --git a/outport/process/outportDataProvider_test.go b/outport/process/outportDataProvider_test.go index ef1422d230a..af0a64eaf04 100644 --- a/outport/process/outportDataProvider_test.go +++ b/outport/process/outportDataProvider_test.go @@ -12,16 +12,18 @@ import ( "github.com/multiversx/mx-chain-core-go/data/rewardTx" "github.com/multiversx/mx-chain-core-go/data/smartContractResult" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/outport/mock" "github.com/multiversx/mx-chain-go/outport/process/transactionsfee" "github.com/multiversx/mx-chain-go/testscommon" commonMocks "github.com/multiversx/mx-chain-go/testscommon/common" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" 
"github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/require" ) func createArgOutportDataProvider() ArgOutportDataProvider { @@ -45,6 +47,8 @@ func createArgOutportDataProvider() ArgOutportDataProvider { ExecutionOrderHandler: &commonMocks.TxExecutionOrderHandlerStub{}, Marshaller: &marshallerMock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, + ProofsPool: &dataRetriever.ProofsPoolMock{}, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStubWithNoFlagsDefined(), } } @@ -84,8 +88,8 @@ func TestPrepareOutportSaveBlockData(t *testing.T) { arg := createArgOutportDataProvider() arg.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - GetValidatorsPublicKeysCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) { - return nil, nil + GetValidatorsPublicKeysCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (string, []string, error) { + return "", nil, nil }, GetValidatorsIndexesCalled: func(publicKeys []string, epoch uint32) ([]uint64, error) { return []uint64{0, 1}, nil @@ -128,8 +132,8 @@ func TestOutportDataProvider_GetIntraShardMiniBlocks(t *testing.T) { arg := createArgOutportDataProvider() arg.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - GetValidatorsPublicKeysCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) { - return nil, nil + GetValidatorsPublicKeysCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (string, []string, error) { + return "", nil, nil }, GetValidatorsIndexesCalled: func(publicKeys []string, epoch uint32) ([]uint64, error) { return []uint64{0, 1}, nil @@ -567,6 +571,22 @@ func Test_collectExecutedTxHashes(t *testing.T) { }) } +func TestFindLeaderIndex(t *testing.T) { + t.Parallel() + + leaderKey := "a" + keys := []string{"a", "b", "c", "d", "e", "f", "g"} + require.Equal(t, uint64(0), findLeaderIndex(keys, leaderKey)) + + leaderKey = "g" + keys = []string{"a", "b", "c", "d", "e", "f", "g"} + require.Equal(t, uint64(6), findLeaderIndex(keys, leaderKey)) + + leaderKey = "notFound" + keys = []string{"a", "b", "c", "d", "e", "f", "g"} + require.Equal(t, uint64(0), findLeaderIndex(keys, leaderKey)) +} + func createMbsAndMbHeaders(numPairs int, numTxsPerMb int) ([]*block.MiniBlock, []block.MiniBlockHeader) { mbs := make([]*block.MiniBlock, numPairs) mbHeaders := make([]block.MiniBlockHeader, numPairs) diff --git a/p2p/disabled/networkMessenger.go b/p2p/disabled/networkMessenger.go index 4f854d976bc..c66443a4379 100644 --- a/p2p/disabled/networkMessenger.go +++ b/p2p/disabled/networkMessenger.go @@ -4,6 +4,7 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/p2p" ) @@ -176,8 +177,8 @@ func (netMes *networkMessenger) SignUsingPrivateKey(_ []byte, _ []byte) ([]byte, } // ProcessReceivedMessage returns nil as it is disabled -func (netMes *networkMessenger) ProcessReceivedMessage(_ p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) error { - return nil +func (netMes *networkMessenger) ProcessReceivedMessage(_ p2p.MessageP2P, _ 
core.PeerID, _ p2p.MessageHandler) ([]byte, error) { + return nil, nil } // SetDebugger returns nil as it is disabled diff --git a/process/block/argProcessor.go b/process/block/argProcessor.go index df929214829..13823d3da83 100644 --- a/process/block/argProcessor.go +++ b/process/block/argProcessor.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + nodeFactory "github.com/multiversx/mx-chain-go/cmd/node/factory" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" @@ -28,6 +29,7 @@ type coreComponentsHolder interface { EnableEpochsHandler() common.EnableEpochsHandler RoundNotifier() process.RoundNotifier EnableRoundsHandler() process.EnableRoundsHandler + EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler RoundHandler() consensus.RoundHandler EconomicsData() process.EconomicsDataHandler ProcessStatusHandler() common.ProcessStatusHandler diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go index 0e3c573b23d..7991c98b181 100644 --- a/process/block/baseProcess.go +++ b/process/block/baseProcess.go @@ -59,8 +59,10 @@ type nonceAndHashInfo struct { } type hdrInfo struct { - usedInBlock bool - hdr data.HeaderHandler + usedInBlock bool + hdr data.HeaderHandler + hasProof bool + hasProofRequested bool } type baseProcessor struct { @@ -102,17 +104,18 @@ type baseProcessor struct { blockProcessor blockProcessor txCounter *transactionCounter - outportHandler outport.OutportHandler - outportDataProvider outport.DataProviderOutport - historyRepo dblookupext.HistoryRepository - epochNotifier process.EpochNotifier - enableEpochsHandler common.EnableEpochsHandler - roundNotifier process.RoundNotifier - enableRoundsHandler process.EnableRoundsHandler - vmContainerFactory process.VirtualMachinesContainerFactory - vmContainer process.VirtualMachinesContainer - gasConsumedProvider gasConsumedProvider - economicsData process.EconomicsDataHandler + outportHandler outport.OutportHandler + outportDataProvider outport.DataProviderOutport + historyRepo dblookupext.HistoryRepository + epochNotifier process.EpochNotifier + enableEpochsHandler common.EnableEpochsHandler + roundNotifier process.RoundNotifier + enableRoundsHandler process.EnableRoundsHandler + vmContainerFactory process.VirtualMachinesContainerFactory + vmContainer process.VirtualMachinesContainer + gasConsumedProvider gasConsumedProvider + economicsData process.EconomicsDataHandler + epochChangeGracePeriodHandler common.EpochChangeGracePeriodHandler processDataTriesOnCommitEpoch bool lastRestartNonce uint64 @@ -123,6 +126,9 @@ type baseProcessor struct { mutNonceOfFirstCommittedBlock sync.RWMutex nonceOfFirstCommittedBlock core.OptionalUint64 extraDelayRequestBlockInfo time.Duration + + proofsPool dataRetriever.ProofsPool + chRcvAllHdrs chan bool } type bootStorerDataArgs struct { @@ -343,7 +349,10 @@ func addMissingNonces(diff int64, lastNonce uint64, maxNumNoncesToAdd int) []uin return missingNonces } -func displayHeader(headerHandler data.HeaderHandler) []*display.LineData { +func displayHeader( + headerHandler data.HeaderHandler, + headerProof data.HeaderProofHandler, +) []*display.LineData { var valStatRootHash, epochStartMetaHash, scheduledRootHash []byte metaHeader, isMetaHeader := headerHandler.(data.MetaHeaderHandler) if isMetaHeader { @@ 
-358,7 +367,22 @@ func displayHeader(headerHandler data.HeaderHandler) []*display.LineData { if !check.IfNil(additionalData) { scheduledRootHash = additionalData.GetScheduledRootHash() } - return []*display.LineData{ + + var aggregatedSig, bitmap []byte + var proofShard, proofEpoch uint32 + var proofRound, proofNonce uint64 + var isStartOfEpoch, hasProofInfo bool + if !check.IfNil(headerProof) { + hasProofInfo = true + aggregatedSig, bitmap = headerProof.GetAggregatedSignature(), headerProof.GetPubKeysBitmap() + proofShard = headerProof.GetHeaderShardId() + proofEpoch = headerProof.GetHeaderEpoch() + proofRound = headerProof.GetHeaderRound() + proofNonce = headerProof.GetHeaderNonce() + isStartOfEpoch = headerProof.GetIsStartOfEpoch() + } + + logLines := []*display.LineData{ display.NewLineData(false, []string{ "", "ChainID", @@ -424,6 +448,41 @@ func displayHeader(headerHandler data.HeaderHandler) []*display.LineData { "Epoch start meta hash", logger.DisplayByteSlice(epochStartMetaHash)}), } + + if hasProofInfo { + logLines = append(logLines, + display.NewLineData(false, []string{ + "Header proof", + "Aggregated signature", + logger.DisplayByteSlice(aggregatedSig)}), + display.NewLineData(false, []string{ + "", + "Pub keys bitmap", + logger.DisplayByteSlice(bitmap)}), + display.NewLineData(false, []string{ + "", + "Epoch", + fmt.Sprintf("%d", proofEpoch)}), + display.NewLineData(false, []string{ + "", + "Round", + fmt.Sprintf("%d", proofRound)}), + display.NewLineData(false, []string{ + "", + "Shard", + fmt.Sprintf("%d", proofShard)}), + display.NewLineData(false, []string{ + "", + "Nonce", + fmt.Sprintf("%d", proofNonce)}), + display.NewLineData(true, []string{ + "", + "IsStartOfEpoch", + fmt.Sprintf("%t", isStartOfEpoch)}), + ) + } + + return logLines } // checkProcessorParameters will check the input parameters values @@ -520,10 +579,14 @@ func checkProcessorParameters(arguments ArgBaseProcessor) error { common.ScheduledMiniBlocksFlag, common.StakingV2Flag, common.CurrentRandomnessOnSortingFlag, + common.AndromedaFlag, }) if err != nil { return err } + if check.IfNil(arguments.CoreComponents.EpochChangeGracePeriodHandler()) { + return process.ErrNilEpochChangeGracePeriodHandler + } if check.IfNil(arguments.CoreComponents.RoundNotifier()) { return process.ErrNilRoundNotifier } @@ -596,41 +659,108 @@ func (bp *baseProcessor) verifyFees(header data.HeaderHandler) error { return nil } -// TODO: remove bool parameter and give instead the set to sort -func (bp *baseProcessor) sortHeadersForCurrentBlockByNonce(usedInBlock bool) map[uint32][]data.HeaderHandler { +func (bp *baseProcessor) filterHeadersWithoutProofs() (map[string]*hdrInfo, error) { + removedNonces := make(map[uint32]map[uint64]struct{}) + noncesWithProofs := make(map[uint32]map[uint64]struct{}) + shardIDs := common.GetShardIDs(bp.shardCoordinator.NumberOfShards()) + for shard := range shardIDs { + removedNonces[shard] = make(map[uint64]struct{}) + noncesWithProofs[shard] = make(map[uint64]struct{}) + } + filteredHeadersInfo := make(map[string]*hdrInfo) + + for hdrHash, headerInfo := range bp.hdrsForCurrBlock.hdrHashAndInfo { + if bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, headerInfo.hdr.GetEpoch()) { + if bp.hasMissingProof(headerInfo, hdrHash) { + removedNonces[headerInfo.hdr.GetShardID()][headerInfo.hdr.GetNonce()] = struct{}{} + continue + } + + noncesWithProofs[headerInfo.hdr.GetShardID()][headerInfo.hdr.GetNonce()] = struct{}{} + filteredHeadersInfo[hdrHash] = 
bp.hdrsForCurrBlock.hdrHashAndInfo[hdrHash] + continue + } + + filteredHeadersInfo[hdrHash] = bp.hdrsForCurrBlock.hdrHashAndInfo[hdrHash] + } + + for shard, nonces := range removedNonces { + for nonce := range nonces { + if _, ok := noncesWithProofs[shard][nonce]; !ok { + return nil, fmt.Errorf("%w for shard %d and nonce %d", process.ErrMissingHeaderProof, shard, nonce) + } + } + } + + return filteredHeadersInfo, nil +} + +func (bp *baseProcessor) computeHeadersForCurrentBlock(usedInBlock bool) (map[uint32][]data.HeaderHandler, error) { hdrsForCurrentBlock := make(map[uint32][]data.HeaderHandler) - bp.hdrsForCurrBlock.mutHdrsForBlock.RLock() - for _, headerInfo := range bp.hdrsForCurrBlock.hdrHashAndInfo { + hdrHashAndInfo, err := bp.filterHeadersWithoutProofs() + if err != nil { + return nil, err + } + + for hdrHash, headerInfo := range hdrHashAndInfo { if headerInfo.usedInBlock != usedInBlock { continue } - hdrsForCurrentBlock[headerInfo.hdr.GetShardID()] = append(hdrsForCurrentBlock[headerInfo.hdr.GetShardID()], headerInfo.hdr) - } - bp.hdrsForCurrBlock.mutHdrsForBlock.RUnlock() + if bp.hasMissingProof(headerInfo, hdrHash) { + return nil, fmt.Errorf("%w for header with hash %s", process.ErrMissingHeaderProof, hex.EncodeToString([]byte(hdrHash))) + } - // sort headers for each shard - for _, hdrsForShard := range hdrsForCurrentBlock { - process.SortHeadersByNonce(hdrsForShard) + hdrsForCurrentBlock[headerInfo.hdr.GetShardID()] = append(hdrsForCurrentBlock[headerInfo.hdr.GetShardID()], headerInfo.hdr) } - return hdrsForCurrentBlock + return hdrsForCurrentBlock, nil } -func (bp *baseProcessor) sortHeaderHashesForCurrentBlockByNonce(usedInBlock bool) map[uint32][][]byte { +func (bp *baseProcessor) computeHeadersForCurrentBlockInfo(usedInBlock bool) (map[uint32][]*nonceAndHashInfo, error) { hdrsForCurrentBlockInfo := make(map[uint32][]*nonceAndHashInfo) - bp.hdrsForCurrBlock.mutHdrsForBlock.RLock() for metaBlockHash, headerInfo := range bp.hdrsForCurrBlock.hdrHashAndInfo { if headerInfo.usedInBlock != usedInBlock { continue } + if bp.hasMissingProof(headerInfo, metaBlockHash) { + return nil, fmt.Errorf("%w for header with hash %s", process.ErrMissingHeaderProof, hex.EncodeToString([]byte(metaBlockHash))) + } + hdrsForCurrentBlockInfo[headerInfo.hdr.GetShardID()] = append(hdrsForCurrentBlockInfo[headerInfo.hdr.GetShardID()], &nonceAndHashInfo{nonce: headerInfo.hdr.GetNonce(), hash: []byte(metaBlockHash)}) } + + return hdrsForCurrentBlockInfo, nil +} + +// TODO: remove bool parameter and give instead the set to sort +func (bp *baseProcessor) sortHeadersForCurrentBlockByNonce(usedInBlock bool) (map[uint32][]data.HeaderHandler, error) { + bp.hdrsForCurrBlock.mutHdrsForBlock.RLock() + hdrsForCurrentBlock, err := bp.computeHeadersForCurrentBlock(usedInBlock) bp.hdrsForCurrBlock.mutHdrsForBlock.RUnlock() + if err != nil { + return nil, err + } + + // sort headers for each shard + for _, hdrsForShard := range hdrsForCurrentBlock { + process.SortHeadersByNonce(hdrsForShard) + } + + return hdrsForCurrentBlock, nil +} + +func (bp *baseProcessor) sortHeaderHashesForCurrentBlockByNonce(usedInBlock bool) (map[uint32][][]byte, error) { + bp.hdrsForCurrBlock.mutHdrsForBlock.RLock() + hdrsForCurrentBlockInfo, err := bp.computeHeadersForCurrentBlockInfo(usedInBlock) + bp.hdrsForCurrBlock.mutHdrsForBlock.RUnlock() + if err != nil { + return nil, err + } for _, hdrsForShard := range hdrsForCurrentBlockInfo { if len(hdrsForShard) > 1 { @@ -647,7 +777,16 @@ func (bp *baseProcessor) 
sortHeaderHashesForCurrentBlockByNonce(usedInBlock bool } } - return hdrsHashesForCurrentBlock + return hdrsHashesForCurrentBlock, nil +} + +func (bp *baseProcessor) hasMissingProof(headerInfo *hdrInfo, hdrHash string) bool { + isFlagEnabledForHeader := bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, headerInfo.hdr.GetEpoch()) && headerInfo.hdr.GetNonce() >= 1 + if !isFlagEnabledForHeader { + return false + } + + return !bp.proofsPool.HasProof(headerInfo.hdr.GetShardID(), []byte(hdrHash)) } func (bp *baseProcessor) createMiniBlockHeaderHandlers( @@ -799,7 +938,6 @@ func isPartiallyExecuted( ) bool { processedMiniBlockInfo := processedMiniBlocksDestMeInfo[string(miniBlockHeaderHandler.GetHash())] return processedMiniBlockInfo != nil && !processedMiniBlockInfo.FullyProcessed - } // check if header has the same miniblocks as presented in body @@ -957,7 +1095,18 @@ func (bp *baseProcessor) cleanupPools(headerHandler data.HeaderHandler) { bp.removeHeadersBehindNonceFromPools( true, bp.shardCoordinator.SelfId(), - highestPrevFinalBlockNonce) + highestPrevFinalBlockNonce, + ) + + if common.IsFlagEnabledAfterEpochsStartBlock(headerHandler, bp.enableEpochsHandler, common.AndromedaFlag) { + err := bp.dataPool.Proofs().CleanupProofsBehindNonce(bp.shardCoordinator.SelfId(), highestPrevFinalBlockNonce) + if err != nil { + log.Warn("failed to cleanup notarized proofs behind nonce", + "nonce", noncesToPrevFinal, + "shardID", bp.shardCoordinator.SelfId(), + "error", err) + } + } if bp.shardCoordinator.SelfId() == core.MetachainShardId { for shardID := uint32(0); shardID < bp.shardCoordinator.NumberOfShards(); shardID++ { @@ -966,6 +1115,7 @@ func (bp *baseProcessor) cleanupPools(headerHandler data.HeaderHandler) { } else { bp.cleanupPoolsForCrossShard(core.MetachainShardId, noncesToPrevFinal) } + } func (bp *baseProcessor) cleanupPoolsForCrossShard( @@ -986,6 +1136,16 @@ func (bp *baseProcessor) cleanupPoolsForCrossShard( shardID, crossNotarizedHeader.GetNonce(), ) + + if common.IsFlagEnabledAfterEpochsStartBlock(crossNotarizedHeader, bp.enableEpochsHandler, common.AndromedaFlag) { + err = bp.dataPool.Proofs().CleanupProofsBehindNonce(shardID, noncesToPrevFinal) + if err != nil { + log.Warn("failed to cleanup notarized proofs behind nonce", + "nonce", noncesToPrevFinal, + "shardID", shardID, + "error", err) + } + } } func (bp *baseProcessor) removeHeadersBehindNonceFromPools( @@ -1348,7 +1508,7 @@ func (bp *baseProcessor) saveShardHeader(header data.HeaderHandler, headerHash [ startTime := time.Now() nonceToByteSlice := bp.uint64Converter.ToByteSlice(header.GetNonce()) - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(header.GetShardID()) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(header.GetShardID()) errNotCritical := bp.store.Put(hdrNonceHashDataUnit, nonceToByteSlice, headerHash) if errNotCritical != nil { @@ -1364,6 +1524,8 @@ func (bp *baseProcessor) saveShardHeader(header data.HeaderHandler, headerHash [ "err", errNotCritical) } + bp.saveProof(headerHash, header) + elapsedTime := time.Since(startTime) if elapsedTime >= common.PutInStorerMaxTime { log.Warn("saveShardHeader", "elapsed time", elapsedTime) @@ -1389,12 +1551,48 @@ func (bp *baseProcessor) saveMetaHeader(header data.HeaderHandler, headerHash [] "err", errNotCritical) } + bp.saveProof(headerHash, header) + elapsedTime := time.Since(startTime) if elapsedTime >= common.PutInStorerMaxTime { log.Warn("saveMetaHeader", "elapsed time", elapsedTime) } } +func (bp 
*baseProcessor) saveProof( + hash []byte, + header data.HeaderHandler, +) { + if !common.IsProofsFlagEnabledForHeader(bp.enableEpochsHandler, header) { + return + } + + proof, err := bp.proofsPool.GetProof(header.GetShardID(), hash) + if err != nil { + log.Error("could not find proof for header", + "hash", hex.EncodeToString(hash), + "shard", header.GetShardID(), + ) + return + } + marshalledProof, errNotCritical := bp.marshalizer.Marshal(proof) + if errNotCritical != nil { + logging.LogErrAsWarnExceptAsDebugIfClosingError(log, errNotCritical, + "saveProof.Marshal proof", + "err", errNotCritical) + return + } + + errNotCritical = bp.store.Put(dataRetriever.ProofsUnit, proof.GetHeaderHash(), marshalledProof) + if errNotCritical != nil { + logging.LogErrAsWarnExceptAsDebugIfClosingError(log, errNotCritical, + "saveProof.Put -> ProofsUnit", + "err", errNotCritical) + } + + log.Trace("saved proof to storage", "hash", hash) +} + func getLastSelfNotarizedHeaderByItself(chainHandler data.ChainHandler) (data.HeaderHandler, []byte) { currentHeader := chainHandler.GetCurrentBlockHeader() if check.IfNil(currentHeader) { @@ -1654,6 +1852,12 @@ func (bp *baseProcessor) restoreBlockBody(headerHandler data.HeaderHandler, body go bp.txCounter.headerReverted(headerHandler) } +// RemoveHeaderFromPool removes the header from the pool +func (bp *baseProcessor) RemoveHeaderFromPool(headerHash []byte) { + headersPool := bp.dataPool.Headers() + headersPool.RemoveHeaderByHash(headerHash) +} + // RestoreBlockBodyIntoPools restores the block body into associated pools func (bp *baseProcessor) RestoreBlockBodyIntoPools(bodyHandler data.BodyHandler) error { if check.IfNil(bodyHandler) { @@ -2119,7 +2323,7 @@ func (bp *baseProcessor) setNonceOfFirstCommittedBlock(nonce uint64) { } func (bp *baseProcessor) checkSentSignaturesAtCommitTime(header data.HeaderHandler) error { - validatorsGroup, err := headerCheck.ComputeConsensusGroup(header, bp.nodesCoordinator) + _, validatorsGroup, err := headerCheck.ComputeConsensusGroup(header, bp.nodesCoordinator) if err != nil { return err } @@ -2137,3 +2341,67 @@ func (bp *baseProcessor) checkSentSignaturesAtCommitTime(header data.HeaderHandl return nil } + +func (bp *baseProcessor) getHeaderHash(header data.HeaderHandler) ([]byte, error) { + marshalledHeader, errMarshal := bp.marshalizer.Marshal(header) + if errMarshal != nil { + return nil, errMarshal + } + + return bp.hasher.Compute(string(marshalledHeader)), nil +} + +func (bp *baseProcessor) requestProofIfNeeded(currentHeaderHash []byte, header data.HeaderHandler) bool { + if !bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + return false + } + if bp.proofsPool.HasProof(header.GetShardID(), currentHeaderHash) { + _, ok := bp.hdrsForCurrBlock.hdrHashAndInfo[string(currentHeaderHash)] + if ok { + bp.hdrsForCurrBlock.hdrHashAndInfo[string(currentHeaderHash)].hasProof = true + } + + return true + } + + _, ok := bp.hdrsForCurrBlock.hdrHashAndInfo[string(currentHeaderHash)] + if !ok { + bp.hdrsForCurrBlock.hdrHashAndInfo[string(currentHeaderHash)] = &hdrInfo{ + hdr: header, + } + } + + bp.hdrsForCurrBlock.hdrHashAndInfo[string(currentHeaderHash)].hasProofRequested = true + bp.hdrsForCurrBlock.missingProofs++ + go bp.requestHandler.RequestEquivalentProofByHash(header.GetShardID(), currentHeaderHash) + + return false +} + +func (bp *baseProcessor) checkReceivedProofIfAttestingIsNeeded(proof data.HeaderProofHandler) { + bp.hdrsForCurrBlock.mutHdrsForBlock.Lock() + hdrHashAndInfo, ok := 
bp.hdrsForCurrBlock.hdrHashAndInfo[string(proof.GetHeaderHash())] + if !ok { + bp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() + return // proof not missing + } + + isWaitingForProofs := hdrHashAndInfo.hasProofRequested + if !isWaitingForProofs { + bp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() + return + } + + hdrHashAndInfo.hasProof = true + bp.hdrsForCurrBlock.missingProofs-- + + missingMetaHdrs := bp.hdrsForCurrBlock.missingHdrs + missingFinalityAttestingMetaHdrs := bp.hdrsForCurrBlock.missingFinalityAttestingHdrs + missingProofs := bp.hdrsForCurrBlock.missingProofs + bp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() + + allMissingMetaHeadersReceived := missingMetaHdrs == 0 && missingFinalityAttestingMetaHdrs == 0 && missingProofs == 0 + if allMissingMetaHeadersReceived { + bp.chRcvAllHdrs <- true + } +} diff --git a/process/block/baseProcess_test.go b/process/block/baseProcess_test.go index c52c5bece52..d4b02ea27d2 100644 --- a/process/block/baseProcess_test.go +++ b/process/block/baseProcess_test.go @@ -24,10 +24,16 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters/uint64ByteSlice" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/graceperiod" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/process" blproc "github.com/multiversx/mx-chain-go/process/block" "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" @@ -40,6 +46,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" commonMocks "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" @@ -55,10 +62,10 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) +var expectedErr = errors.New("expected error") + const ( busyIdentifier = "busy" idleIdentifier = "idle" @@ -76,8 +83,9 @@ func createArgBaseProcessor( ) blproc.ArgBaseProcessor { nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: 
&enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } headerValidator, _ := blproc.NewHeaderValidator(argsHeaderValidator) @@ -140,8 +148,8 @@ func createTestBlockchain() *testscommon.ChainHandlerStub { } func generateTestCache() storage.Cacher { - cache, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) - return cache + c, _ := storageunit.NewCache(storageunit.CacheConfig{Type: storageunit.LRUCache, Capacity: 1000, Shards: 1, SizeInBytes: 0}) + return c } func generateTestUnit() storage.Storer { @@ -160,7 +168,7 @@ func createShardedDataChacherNotifier( return func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, testHash) { return handler, true @@ -207,7 +215,7 @@ func initDataPool(testHash []byte) *dataRetrieverMock.PoolsHolderStub { UnsignedTransactionsCalled: unsignedTxCalled, RewardTransactionsCalled: rewardTransactionsCalled, MetaBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &transaction.Transaction{Nonce: 10}, true @@ -234,7 +242,7 @@ func initDataPool(testHash []byte) *dataRetrieverMock.PoolsHolderStub { } }, MiniBlocksCalled: func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -284,6 +292,9 @@ func initDataPool(testHash []byte) *dataRetrieverMock.PoolsHolderStub { } return cs }, + ProofsCalled: func() dataRetriever.ProofsPool { + return proofscache.NewProofsPool(3, 100) + }, } return sdp @@ -379,17 +390,19 @@ func createComponentHolderMocks() ( blkc, _ := blockchain.NewBlockChain(&statusHandlerMock.AppStatusHandlerStub{}) _ = blkc.SetGenesisHeader(&block.Header{Nonce: 0}) + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) coreComponents := &mock.CoreComponentsMock{ - IntMarsh: &mock.MarshalizerMock{}, - Hash: &mock.HasherStub{}, - UInt64ByteSliceConv: &mock.Uint64ByteSliceConverterMock{}, - StatusField: &statusHandlerMock.AppStatusHandlerStub{}, - RoundField: &mock.RoundHandlerMock{}, - ProcessStatusHandlerField: &testscommon.ProcessStatusHandlerStub{}, - EpochNotifierField: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandlerField: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), - RoundNotifierField: &epochNotifier.RoundNotifierStub{}, - EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, + IntMarsh: &mock.MarshalizerMock{}, + Hash: &mock.HasherStub{}, + UInt64ByteSliceConv: &mock.Uint64ByteSliceConverterMock{}, + StatusField: &statusHandlerMock.AppStatusHandlerStub{}, + RoundField: &mock.RoundHandlerMock{}, + ProcessStatusHandlerField: &testscommon.ProcessStatusHandlerStub{}, + EpochNotifierField: &epochNotifier.EpochNotifierStub{}, + EnableEpochsHandlerField: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + RoundNotifierField: &epochNotifier.RoundNotifierStub{}, + EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, + EpochChangeGracePeriodHandlerField: gracePeriod, } dataComponents := 
&mock.DataComponentsMock{ @@ -777,6 +790,38 @@ func TestCheckProcessorNilParameters(t *testing.T) { }, expectedErr: process.ErrNilManagedPeersHolder, }, + { + args: func() blproc.ArgBaseProcessor { + args := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + args.OutportDataProvider = nil + return args + }, + expectedErr: process.ErrNilOutportDataProvider, + }, + { + args: func() blproc.ArgBaseProcessor { + args := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + args.BlockProcessingCutoffHandler = nil + return args + }, + expectedErr: process.ErrNilBlockProcessingCutoffHandler, + }, + { + args: func() blproc.ArgBaseProcessor { + args := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + args.ManagedPeersHolder = nil + return args + }, + expectedErr: process.ErrNilManagedPeersHolder, + }, + { + args: func() blproc.ArgBaseProcessor { + args := createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) + args.SentSignaturesTracker = nil + return args + }, + expectedErr: process.ErrNilSentSignatureTracker, + }, { args: func() blproc.ArgBaseProcessor { return createArgBaseProcessor(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -1262,8 +1307,9 @@ func TestBaseProcessor_SaveLastNotarizedHdrShardGood(t *testing.T) { sp, _ := blproc.NewShardProcessor(arguments) argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: coreComponents.Hasher(), - Marshalizer: coreComponents.InternalMarshalizer(), + Hasher: coreComponents.Hasher(), + Marshalizer: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } headerValidator, _ := blproc.NewHeaderValidator(argsHeaderValidator) sp.SetHeaderValidator(headerValidator) @@ -1296,8 +1342,9 @@ func TestBaseProcessor_SaveLastNotarizedHdrMetaGood(t *testing.T) { sp, _ := blproc.NewShardProcessor(arguments) argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: coreComponents.Hasher(), - Marshalizer: coreComponents.InternalMarshalizer(), + Hasher: coreComponents.Hasher(), + Marshalizer: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } headerValidator, _ := blproc.NewHeaderValidator(argsHeaderValidator) sp.SetHeaderValidator(headerValidator) @@ -1961,7 +2008,6 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeeded_GetAllLeaves(t *testing.T arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) - expectedErr := errors.New("expected error") arguments.AccountsDB = map[state.AccountsDbIdentifier]state.AccountsAdapter{ state.UserAccountsState: &stateMock.AccountsStub{ RootHashCalled: func() ([]byte, error) { @@ -1999,7 +2045,6 @@ func TestBaseProcessor_commitTrieEpochRootHashIfNeeded_GetAllLeaves(t *testing.T arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) - expectedErr := errors.New("expected error") arguments.AccountsDB = map[state.AccountsDbIdentifier]state.AccountsAdapter{ state.UserAccountsState: &stateMock.AccountsStub{ RootHashCalled: func() ([]byte, error) { @@ -2686,7 +2731,6 @@ func TestBaseProcessor_checkScheduledMiniBlockValidity(t *testing.T) { coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() coreComponents.EnableEpochsHandlerField = 
enableEpochsHandlerMock.NewEnableEpochsHandlerStub(common.ScheduledMiniBlocksFlag) - expectedErr := errors.New("expected error") coreComponents.IntMarsh = &marshallerMock.MarshalizerStub{ MarshalCalled: func(obj interface{}) ([]byte, error) { return nil, expectedErr @@ -3118,11 +3162,10 @@ func TestBaseProcessor_ConcurrentCallsNonceOfFirstCommittedBlock(t *testing.T) { func TestBaseProcessor_CheckSentSignaturesAtCommitTime(t *testing.T) { t.Parallel() - expectedErr := errors.New("expected error") t.Run("nodes coordinator errors, should return error", func(t *testing.T) { nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() - nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return nil, expectedErr + nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return nil, nil, expectedErr } arguments := CreateMockArguments(createComponentHolderMocks()) @@ -3134,7 +3177,10 @@ func TestBaseProcessor_CheckSentSignaturesAtCommitTime(t *testing.T) { arguments.NodesCoordinator = nodesCoordinatorInstance bp, _ := blproc.NewShardProcessor(arguments) - err := bp.CheckSentSignaturesAtCommitTime(&block.Header{}) + err := bp.CheckSentSignaturesAtCommitTime(&block.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + }) assert.Equal(t, expectedErr, err) }) t.Run("should work with bitmap", func(t *testing.T) { @@ -3143,8 +3189,8 @@ func TestBaseProcessor_CheckSentSignaturesAtCommitTime(t *testing.T) { validator2, _ := nodesCoordinator.NewValidator([]byte("pk2"), 2, 2) nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() - nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{validator0, validator1, validator2}, nil + nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator0, []nodesCoordinator.Validator{validator0, validator1, validator2}, nil } resetCountersCalled := make([][]byte, 0) @@ -3158,6 +3204,8 @@ func TestBaseProcessor_CheckSentSignaturesAtCommitTime(t *testing.T) { bp, _ := blproc.NewShardProcessor(arguments) err := bp.CheckSentSignaturesAtCommitTime(&block.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), PubKeysBitmap: []byte{0b00000101}, }) assert.Nil(t, err) @@ -3165,3 +3213,181 @@ func TestBaseProcessor_CheckSentSignaturesAtCommitTime(t *testing.T) { assert.Equal(t, [][]byte{validator0.PubKey(), validator2.PubKey()}, resetCountersCalled) }) } + +func TestBaseProcessor_FilterHeadersWithoutProofs(t *testing.T) { + t.Parallel() + + headersForCurrentBlock := map[string]data.HeaderHandler{ + "hash0": &testscommon.HeaderHandlerStub{ + EpochField: 12, + GetNonceCalled: func() uint64 { + return 1 + }, + GetShardIDCalled: func() uint32 { + return 0 + }, + }, + "hash1": &testscommon.HeaderHandlerStub{ + EpochField: 12, + GetNonceCalled: func() uint64 { + return 1 + }, + GetShardIDCalled: func() uint32 { + return 1 + }, + }, + "hash2": &testscommon.HeaderHandlerStub{ + 
EpochField: 12, // no proof for this one, should be marked for deletion + GetNonceCalled: func() uint64 { + return 2 + }, + GetShardIDCalled: func() uint32 { + return 0 + }, + }, + "hash3": &testscommon.HeaderHandlerStub{ + EpochField: 1, // flag not active, for coverage only + GetNonceCalled: func() uint64 { + return 2 + }, + GetShardIDCalled: func() uint32 { + return 1 + }, + }, + } + coreComp, dataComp, bootstrapComp, statusComp := createComponentHolderMocks() + bootstrapComp.Coordinator = &mock.ShardCoordinatorStub{ + NumberOfShardsCalled: func() uint32 { + return 2 + }, + } + coreComp.EnableEpochsHandlerField = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return epoch == 12 + }, + } + dataPool := initDataPool([]byte("")) + dataPool.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return string(headerHash) != "hash2" + }, + } + } + dataComp.DataPool = dataPool + arguments := CreateMockArguments(coreComp, dataComp, bootstrapComp, statusComp) + bp, _ := blproc.NewShardProcessor(arguments) + + for hash, header := range headersForCurrentBlock { + bp.SetHdrForCurrentBlock([]byte(hash), header, true) + } + + // this call should fail because header with nonce 2 from shard 0 (hash2) does not have proof + // and there is no other header with the same nonce and proof + headersWithProofs, err := bp.FilterHeadersWithoutProofs() + require.True(t, errors.Is(err, process.ErrMissingHeaderProof)) + require.Nil(t, headersWithProofs) + + // add one more header with same nonce as hash2, but this one has proof + bp.SetHdrForCurrentBlock( + []byte("hash4"), + &testscommon.HeaderHandlerStub{ + EpochField: 12, // same nonce as above, but this one has proof + GetNonceCalled: func() uint64 { + return 2 + }, + GetShardIDCalled: func() uint32 { + return 0 + }, + }, + true, + ) + + // this call should succeed, as for nonce 2 in shard 0 we have 2 headers, hash2 and hash4, but hash4 has proof + headersWithProofs, err = bp.FilterHeadersWithoutProofs() + require.NoError(t, err) + require.Equal(t, 4, len(headersWithProofs)) + + returnedHashes := make([]string, 0, len(headersWithProofs)) + for hash := range headersWithProofs { + returnedHashes = append(returnedHashes, hash) + } + slices.Sort(returnedHashes) + + expectedSortedHashes := []string{"hash0", "hash1", "hash3", "hash4"} + require.Equal(t, expectedSortedHashes, returnedHashes) +} + +func TestBaseProcessor_DisplayHeader(t *testing.T) { + t.Parallel() + + t.Run("shard header with proof info", func(t *testing.T) { + t.Parallel() + + header := &block.HeaderV2{ + Header: &block.Header{ + ChainID: []byte("1"), + Epoch: 2, + Round: 3, + TimeStamp: 4, + Nonce: 5, + PrevHash: []byte("prevHash"), + PrevRandSeed: []byte("prevRandSeed"), + RandSeed: []byte("randSeed"), + LeaderSignature: []byte("leaderSig"), + RootHash: []byte("rootHash"), + ReceiptsHash: []byte("receiptsHash"), + }, + ScheduledRootHash: []byte("schRootHash"), + ScheduledAccumulatedFees: big.NewInt(6), + ScheduledDeveloperFees: big.NewInt(7), + ScheduledGasProvided: 8, + ScheduledGasPenalized: 9, + ScheduledGasRefunded: 10, + } + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("prevHash"), + HeaderEpoch: 2, + HeaderNonce: 4, + HeaderShardId: 0, + HeaderRound: 2, + IsStartOfEpoch: false, + } + + lines := blproc.DisplayHeader(header, 
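The behaviour exercised by TestBaseProcessor_FilterHeadersWithoutProofs above can be reduced to the sketch below: keep a header only when a proof exists for its hash, and fail when a (shard, nonce) pair ends up with no proven header at all. The types and the hasProof callback are simplified placeholders, and the epoch-flag exception covered by "hash3" is intentionally left out of this sketch.

package sketch

import (
	"errors"
	"fmt"
)

var errMissingHeaderProof = errors.New("missing header proof")

type headerInfo struct {
	shardID uint32
	nonce   uint64
}

// filterByProof keeps only the headers with a proof and errors out when some
// (shard, nonce) pair has no proven header left.
func filterByProof(headers map[string]headerInfo, hasProof func(hash string) bool) (map[string]headerInfo, error) {
	kept := make(map[string]headerInfo)
	proven := make(map[headerInfo]bool) // (shard, nonce) pairs that kept at least one header
	seen := make(map[headerInfo]bool)   // all (shard, nonce) pairs encountered

	for hash, info := range headers {
		seen[info] = true
		if hasProof(hash) {
			kept[hash] = info
			proven[info] = true
		}
	}

	for info := range seen {
		if !proven[info] {
			return nil, fmt.Errorf("%w for shard %d nonce %d", errMissingHeaderProof, info.shardID, info.nonce)
		}
	}

	return kept, nil
}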
proof) + require.Equal(t, 23, len(lines)) + }) + t.Run("meta header with proof info", func(t *testing.T) { + t.Parallel() + + header := &block.MetaBlock{ + Nonce: 5, + Epoch: 2, + Round: 3, + TimeStamp: 4, + LeaderSignature: []byte("leaderSig"), + PrevHash: []byte("prevHash"), + PrevRandSeed: []byte("prevRandSeed"), + RandSeed: []byte("randSeed"), + RootHash: []byte("rootHash"), + ReceiptsHash: []byte("receiptsHash"), + EpochStart: block.EpochStart{}, + ChainID: []byte("1"), + } + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("prevHash"), + HeaderEpoch: 2, + HeaderNonce: 4, + HeaderShardId: 0, + HeaderRound: 2, + IsStartOfEpoch: false, + } + + lines := blproc.DisplayHeader(header, proof) + require.Equal(t, 23, len(lines)) + }) +} diff --git a/process/block/bootstrapStorage/bootstrapData.pb.go b/process/block/bootstrapStorage/bootstrapData.pb.go index b27029a205e..d6b3ef8006f 100644 --- a/process/block/bootstrapStorage/bootstrapData.pb.go +++ b/process/block/bootstrapStorage/bootstrapData.pb.go @@ -26,7 +26,7 @@ var _ = math.Inf // proto package needs to be updated. const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package -//MiniBlocksInMeta is used to store all mini blocks hashes for a metablock hash +// MiniBlocksInMeta is used to store all mini blocks hashes for a metablock hash type MiniBlocksInMeta struct { MetaHash []byte `protobuf:"bytes,1,opt,name=MetaHash,proto3" json:"MetaHash,omitempty"` MiniBlocksHashes [][]byte `protobuf:"bytes,2,rep,name=MiniBlocksHashes,proto3" json:"MiniBlocksHashes,omitempty"` @@ -90,7 +90,7 @@ func (m *MiniBlocksInMeta) GetIndexOfLastTxProcessed() []int32 { return nil } -//BootstrapHeaderInfo is used to store information about a header +// BootstrapHeaderInfo is used to store information about a header type BootstrapHeaderInfo struct { ShardId uint32 `protobuf:"varint,1,opt,name=ShardId,proto3" json:"ShardId,omitempty"` Epoch uint32 `protobuf:"varint,2,opt,name=Epoch,proto3" json:"Epoch,omitempty"` @@ -154,7 +154,7 @@ func (m *BootstrapHeaderInfo) GetHash() []byte { return nil } -//PendingMiniBlocksInfo is used to store information about the number of pending miniblocks +// PendingMiniBlocksInfo is used to store information about the number of pending miniblocks type PendingMiniBlocksInfo struct { ShardID uint32 `protobuf:"varint,1,opt,name=ShardID,proto3" json:"ShardID,omitempty"` MiniBlocksHashes [][]byte `protobuf:"bytes,2,rep,name=MiniBlocksHashes,proto3" json:"MiniBlocksHashes,omitempty"` diff --git a/process/block/displayBlock.go b/process/block/displayBlock.go index 3b1ab7410cc..176114deab9 100644 --- a/process/block/displayBlock.go +++ b/process/block/displayBlock.go @@ -134,10 +134,11 @@ func (txc *transactionCounter) displayLogInfo( headerHash []byte, numShards uint32, selfId uint32, - _ dataRetriever.PoolsHolder, + dataPool dataRetriever.PoolsHolder, blockTracker process.BlockTracker, ) { - dispHeader, dispLines := txc.createDisplayableShardHeaderAndBlockBody(header, body) + headerProof, _ := dataPool.Proofs().GetProof(selfId, headerHash) + dispHeader, dispLines := txc.createDisplayableShardHeaderAndBlockBody(header, body, headerProof) tblString, err := display.CreateTableString(dispHeader, dispLines) if err != nil { @@ -162,6 +163,7 @@ func (txc *transactionCounter) displayLogInfo( func (txc *transactionCounter) createDisplayableShardHeaderAndBlockBody( header data.HeaderHandler, body *block.Body, + headerProof data.HeaderProofHandler, ) 
([]string, []*display.LineData) { tableHeader := []string{"Part", "Parameter", "Value"} @@ -177,7 +179,7 @@ func (txc *transactionCounter) createDisplayableShardHeaderAndBlockBody( fmt.Sprintf("%d", header.GetShardID())}), } - lines := displayHeader(header) + lines := displayHeader(header, headerProof) shardLines := make([]*display.LineData, 0, len(lines)+len(headerLines)) shardLines = append(shardLines, headerLines...) @@ -268,7 +270,7 @@ func (txc *transactionCounter) displayTxBlockBody( miniBlock.SenderShardID, miniBlock.ReceiverShardID) - if miniBlock.TxHashes == nil || len(miniBlock.TxHashes) == 0 { + if len(miniBlock.TxHashes) == 0 { lines = append(lines, display.NewLineData(false, []string{ part, "", ""})) } diff --git a/process/block/displayMetaBlock.go b/process/block/displayMetaBlock.go index 2018b819925..d43cb9290e6 100644 --- a/process/block/displayMetaBlock.go +++ b/process/block/displayMetaBlock.go @@ -4,10 +4,12 @@ import ( "fmt" "sync" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/display" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-logger-go" ) @@ -70,6 +72,7 @@ func (hc *headersCounter) displayLogInfo( headerHash []byte, numShardHeadersFromPool int, blockTracker process.BlockTracker, + dataPool dataRetriever.PoolsHolder, ) { if check.IfNil(countersProvider) { log.Warn("programming error in headersCounter.displayLogInfo - nil countersProvider") @@ -78,7 +81,8 @@ func (hc *headersCounter) displayLogInfo( hc.calculateNumOfShardMBHeaders(header) - dispHeader, dispLines := hc.createDisplayableMetaHeader(header) + headerProof, _ := dataPool.Proofs().GetProof(core.MetachainShardId, headerHash) + dispHeader, dispLines := hc.createDisplayableMetaHeader(header, headerProof) dispLines = hc.displayTxBlockBody(dispLines, header, body) tblString, err := display.CreateTableString(dispHeader, dispLines) @@ -109,6 +113,7 @@ func (hc *headersCounter) displayLogInfo( func (hc *headersCounter) createDisplayableMetaHeader( header *block.MetaBlock, + headerProof data.HeaderProofHandler, ) ([]string, []*display.LineData) { tableHeader := []string{"Part", "Parameter", "Value"} @@ -122,9 +127,9 @@ func (hc *headersCounter) createDisplayableMetaHeader( var lines []*display.LineData if header.IsStartOfEpochBlock() { - lines = displayEpochStartMetaBlock(header) + lines = displayEpochStartMetaBlock(header, headerProof) } else { - lines = displayHeader(header) + lines = displayHeader(header, headerProof) } metaLines := make([]*display.LineData, 0, len(lines)+len(metaLinesHeader)) @@ -145,7 +150,7 @@ func (hc *headersCounter) displayShardInfo(lines []*display.LineData, header *bl "Header hash", logger.DisplayByteSlice(shardData.HeaderHash)})) - if shardData.ShardMiniBlockHeaders == nil || len(shardData.ShardMiniBlockHeaders) == 0 { + if len(shardData.ShardMiniBlockHeaders) == 0 { lines = append(lines, display.NewLineData(false, []string{ "", "ShardMiniBlockHeaders", ""})) } @@ -197,7 +202,7 @@ func (hc *headersCounter) displayTxBlockBody( miniBlock.SenderShardID, miniBlock.ReceiverShardID) - if miniBlock.TxHashes == nil || len(miniBlock.TxHashes) == 0 { + if len(miniBlock.TxHashes) == 0 { lines = 
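The display changes above follow an "optional proof" pattern: displayLogInfo fetches the proof for the header hash from the proofs pool and the renderer only adds proof lines when one exists. A hedged sketch of that pattern, with hypothetical types and a lookupProof callback standing in for dataPool.Proofs().GetProof:

package sketch

import "fmt"

type headerProof struct {
	aggregatedSignature []byte
	pubKeysBitmap       []byte
}

// displayWithOptionalProof renders the header line and appends proof lines
// only when a proof is available for the given hash.
func displayWithOptionalProof(headerHash []byte, lookupProof func(hash []byte) (*headerProof, bool)) []string {
	lines := []string{fmt.Sprintf("header hash: %x", headerHash)}

	proof, ok := lookupProof(headerHash)
	if !ok || proof == nil {
		return lines // no proof stored: header is rendered as before the change
	}

	lines = append(lines,
		fmt.Sprintf("aggregated signature: %x", proof.aggregatedSignature),
		fmt.Sprintf("pub keys bitmap: %x", proof.pubKeysBitmap),
	)
	return lines
}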
append(lines, display.NewLineData(false, []string{ part, "", ""})) } @@ -240,8 +245,11 @@ func (hc *headersCounter) getNumShardMBHeadersTotalProcessed() uint64 { return hc.shardMBHeadersTotalProcessed } -func displayEpochStartMetaBlock(block *block.MetaBlock) []*display.LineData { - lines := displayHeader(block) +func displayEpochStartMetaBlock( + block *block.MetaBlock, + headerProof data.HeaderProofHandler, +) []*display.LineData { + lines := displayHeader(block, headerProof) economicsLines := displayEconomicsData(block.EpochStart.Economics) lines = append(lines, economicsLines...) diff --git a/process/block/export_test.go b/process/block/export_test.go index 2332115613c..750ff2ee15c 100644 --- a/process/block/export_test.go +++ b/process/block/export_test.go @@ -9,9 +9,12 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/scheduled" + "github.com/multiversx/mx-chain-core-go/display" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common/graceperiod" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" @@ -31,14 +34,17 @@ import ( storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" ) +// ComputeHeaderHash - func (bp *baseProcessor) ComputeHeaderHash(hdr data.HeaderHandler) ([]byte, error) { return core.CalculateHash(bp.marshalizer, bp.hasher, hdr) } +// VerifyStateRoot - func (bp *baseProcessor) VerifyStateRoot(rootHash []byte) bool { return bp.verifyStateRoot(rootHash) } +// CheckBlockValidity - func (bp *baseProcessor) CheckBlockValidity( headerHandler data.HeaderHandler, bodyHandler data.BodyHandler, @@ -46,6 +52,7 @@ func (bp *baseProcessor) CheckBlockValidity( return bp.checkBlockValidity(headerHandler, bodyHandler) } +// RemoveHeadersBehindNonceFromPools - func (bp *baseProcessor) RemoveHeadersBehindNonceFromPools( shouldRemoveBlockBody bool, shardId uint32, @@ -54,42 +61,56 @@ func (bp *baseProcessor) RemoveHeadersBehindNonceFromPools( bp.removeHeadersBehindNonceFromPools(shouldRemoveBlockBody, shardId, nonce) } +// GetPruningHandler - func (bp *baseProcessor) GetPruningHandler(finalHeaderNonce uint64) state.PruningHandler { return bp.getPruningHandler(finalHeaderNonce) } +// SetLastRestartNonce - func (bp *baseProcessor) SetLastRestartNonce(lastRestartNonce uint64) { bp.lastRestartNonce = lastRestartNonce } +// CommitTrieEpochRootHashIfNeeded - func (bp *baseProcessor) CommitTrieEpochRootHashIfNeeded(metaBlock *block.MetaBlock, rootHash []byte) error { return bp.commitTrieEpochRootHashIfNeeded(metaBlock, rootHash) } +// FilterHeadersWithoutProofs - +func (bp *baseProcessor) FilterHeadersWithoutProofs() (map[string]*hdrInfo, error) { + return bp.filterHeadersWithoutProofs() +} + +// ReceivedMetaBlock - func (sp *shardProcessor) ReceivedMetaBlock(header data.HeaderHandler, metaBlockHash []byte) { sp.receivedMetaBlock(header, metaBlockHash) } +// CreateMiniBlocks - func (sp *shardProcessor) CreateMiniBlocks(haveTime func() bool) (*block.Body, map[string]*processedMb.ProcessedMiniBlockInfo, error) { return 
sp.createMiniBlocks(haveTime, []byte("random")) } +// GetOrderedProcessedMetaBlocksFromHeader - func (sp *shardProcessor) GetOrderedProcessedMetaBlocksFromHeader(header data.HeaderHandler) ([]data.HeaderHandler, error) { return sp.getOrderedProcessedMetaBlocksFromHeader(header) } +// UpdateCrossShardInfo - func (sp *shardProcessor) UpdateCrossShardInfo(processedMetaHdrs []data.HeaderHandler) error { return sp.updateCrossShardInfo(processedMetaHdrs) } -func (sp *shardProcessor) UpdateStateStorage(finalHeaders []data.HeaderHandler, currentHeader data.HeaderHandler) { +// UpdateStateStorage - +func (sp *shardProcessor) UpdateStateStorage(finalHeaders []data.HeaderHandler, currentHeader data.HeaderHandler, currentHeaderHash []byte) { currShardHeader, ok := currentHeader.(data.ShardHeaderHandler) if !ok { return } - sp.updateState(finalHeaders, currShardHeader) + sp.updateState(finalHeaders, currShardHeader, currentHeaderHash) } +// NewShardProcessorEmptyWith3shards - func NewShardProcessorEmptyWith3shards( tdp dataRetriever.PoolsHolder, genesisBlocks map[uint32]data.HeaderHandler, @@ -99,25 +120,27 @@ func NewShardProcessorEmptyWith3shards( nodesCoordinator := shardingMocks.NewNodesCoordinatorMock() argsHeaderValidator := ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } hdrValidator, _ := NewHeaderValidator(argsHeaderValidator) accountsDb := make(map[state.AccountsDbIdentifier]state.AccountsAdapter) accountsDb[state.UserAccountsState] = &stateMock.AccountsStub{} - + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) coreComponents := &mock.CoreComponentsMock{ - IntMarsh: &mock.MarshalizerMock{}, - Hash: &hashingMocks.HasherMock{}, - UInt64ByteSliceConv: &mock.Uint64ByteSliceConverterMock{}, - StatusField: &statusHandlerMock.AppStatusHandlerStub{}, - RoundField: &mock.RoundHandlerMock{}, - ProcessStatusHandlerField: &testscommon.ProcessStatusHandlerStub{}, - EpochNotifierField: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandlerField: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), - RoundNotifierField: &epochNotifier.RoundNotifierStub{}, - EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, + IntMarsh: &mock.MarshalizerMock{}, + Hash: &hashingMocks.HasherMock{}, + UInt64ByteSliceConv: &mock.Uint64ByteSliceConverterMock{}, + StatusField: &statusHandlerMock.AppStatusHandlerStub{}, + RoundField: &mock.RoundHandlerMock{}, + ProcessStatusHandlerField: &testscommon.ProcessStatusHandlerStub{}, + EpochNotifierField: &epochNotifier.EpochNotifierStub{}, + EnableEpochsHandlerField: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + RoundNotifierField: &epochNotifier.RoundNotifierStub{}, + EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, + EpochChangeGracePeriodHandlerField: gracePeriod, } dataComponents := &mock.DataComponentsMock{ Storage: &storageStubs.ChainStorerStub{}, @@ -175,18 +198,22 @@ func NewShardProcessorEmptyWith3shards( return shardProc, err } -func (mp *metaProcessor) RequestBlockHeaders(header *block.MetaBlock) (uint32, uint32) { +// RequestBlockHeaders - +func (mp *metaProcessor) RequestBlockHeaders(header *block.MetaBlock) (uint32, uint32, uint32) { return mp.requestShardHeaders(header) } +// ReceivedShardHeader - func (mp *metaProcessor) 
ReceivedShardHeader(header data.HeaderHandler, shardHeaderHash []byte) { mp.receivedShardHeader(header, shardHeaderHash) } +// GetDataPool - func (mp *metaProcessor) GetDataPool() dataRetriever.PoolsHolder { return mp.dataPool } +// AddHdrHashToRequestedList - func (mp *metaProcessor) AddHdrHashToRequestedList(hdr data.HeaderHandler, hdrHash []byte) { mp.hdrsForCurrBlock.mutHdrsForBlock.Lock() defer mp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() @@ -203,6 +230,7 @@ func (mp *metaProcessor) AddHdrHashToRequestedList(hdr data.HeaderHandler, hdrHa mp.hdrsForCurrBlock.missingHdrs++ } +// IsHdrMissing - func (mp *metaProcessor) IsHdrMissing(hdrHash []byte) bool { mp.hdrsForCurrBlock.mutHdrsForBlock.RLock() defer mp.hdrsForCurrBlock.mutHdrsForBlock.RUnlock() @@ -215,10 +243,12 @@ func (mp *metaProcessor) IsHdrMissing(hdrHash []byte) bool { return check.IfNil(hdrInfoValue.hdr) } +// CreateShardInfo - func (mp *metaProcessor) CreateShardInfo() ([]data.ShardDataHandler, error) { return mp.createShardInfo() } +// RequestMissingFinalityAttestingShardHeaders - func (mp *metaProcessor) RequestMissingFinalityAttestingShardHeaders() uint32 { mp.hdrsForCurrBlock.mutHdrsForBlock.Lock() defer mp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() @@ -226,10 +256,12 @@ func (mp *metaProcessor) RequestMissingFinalityAttestingShardHeaders() uint32 { return mp.requestMissingFinalityAttestingShardHeaders() } +// SaveMetricCrossCheckBlockHeight - func (mp *metaProcessor) SaveMetricCrossCheckBlockHeight() { mp.saveMetricCrossCheckBlockHeight() } +// NotarizedHdrs - func (bp *baseProcessor) NotarizedHdrs() map[uint32][]data.HeaderHandler { lastCrossNotarizedHeaders := make(map[uint32][]data.HeaderHandler) for shardID := uint32(0); shardID < bp.shardCoordinator.NumberOfShards(); shardID++ { @@ -247,6 +279,7 @@ func (bp *baseProcessor) NotarizedHdrs() map[uint32][]data.HeaderHandler { return lastCrossNotarizedHeaders } +// LastNotarizedHdrForShard - func (bp *baseProcessor) LastNotarizedHdrForShard(shardID uint32) data.HeaderHandler { lastCrossNotarizedHeaderForShard, _, _ := bp.blockTracker.GetLastCrossNotarizedHeader(shardID) if check.IfNil(lastCrossNotarizedHeaderForShard) { @@ -256,76 +289,94 @@ func (bp *baseProcessor) LastNotarizedHdrForShard(shardID uint32) data.HeaderHan return lastCrossNotarizedHeaderForShard } +// SetMarshalizer - func (bp *baseProcessor) SetMarshalizer(marshal marshal.Marshalizer) { bp.marshalizer = marshal } +// SetHasher - func (bp *baseProcessor) SetHasher(hasher hashing.Hasher) { bp.hasher = hasher } +// SetHeaderValidator - func (bp *baseProcessor) SetHeaderValidator(validator process.HeaderConstructionValidator) { bp.headerValidator = validator } +// RequestHeadersIfMissing - func (bp *baseProcessor) RequestHeadersIfMissing(sortedHdrs []data.HeaderHandler, shardId uint32) error { return bp.requestHeadersIfMissing(sortedHdrs, shardId) } +// SetShardBlockFinality - func (mp *metaProcessor) SetShardBlockFinality(val uint32) { mp.hdrsForCurrBlock.mutHdrsForBlock.Lock() mp.shardBlockFinality = val mp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() } +// SaveLastNotarizedHeader - func (mp *metaProcessor) SaveLastNotarizedHeader(header *block.MetaBlock) error { return mp.saveLastNotarizedHeader(header) } +// CheckShardHeadersValidity - func (mp *metaProcessor) CheckShardHeadersValidity(header *block.MetaBlock) (map[uint32]data.HeaderHandler, error) { return mp.checkShardHeadersValidity(header) } +// CheckShardHeadersFinality - func (mp *metaProcessor) CheckShardHeadersFinality(highestNonceHdrs 
map[uint32]data.HeaderHandler) error { return mp.checkShardHeadersFinality(highestNonceHdrs) } +// CheckHeaderBodyCorrelation - func (mp *metaProcessor) CheckHeaderBodyCorrelation(hdr data.HeaderHandler, body *block.Body) error { return mp.checkHeaderBodyCorrelation(hdr.GetMiniBlockHeaderHandlers(), body) } +// IsHdrConstructionValid - func (bp *baseProcessor) IsHdrConstructionValid(currHdr, prevHdr data.HeaderHandler) error { return bp.headerValidator.IsHeaderConstructionValid(currHdr, prevHdr) } +// ChRcvAllHdrs - func (mp *metaProcessor) ChRcvAllHdrs() chan bool { return mp.chRcvAllHdrs } +// UpdateShardsHeadersNonce - func (mp *metaProcessor) UpdateShardsHeadersNonce(key uint32, value uint64) { mp.updateShardHeadersNonce(key, value) } +// GetShardsHeadersNonce - func (mp *metaProcessor) GetShardsHeadersNonce() *sync.Map { return mp.shardsHeadersNonce } +// SaveLastNotarizedHeader - func (sp *shardProcessor) SaveLastNotarizedHeader(shardId uint32, processedHdrs []data.HeaderHandler) error { return sp.saveLastNotarizedHeader(shardId, processedHdrs) } +// CheckHeaderBodyCorrelation - func (sp *shardProcessor) CheckHeaderBodyCorrelation(hdr data.HeaderHandler, body *block.Body) error { return sp.checkHeaderBodyCorrelation(hdr.GetMiniBlockHeaderHandlers(), body) } +// CheckAndRequestIfMetaHeadersMissing - func (sp *shardProcessor) CheckAndRequestIfMetaHeadersMissing() { sp.checkAndRequestIfMetaHeadersMissing() } +// GetHashAndHdrStruct - func (sp *shardProcessor) GetHashAndHdrStruct(header data.HeaderHandler, hash []byte) *hashAndHdr { return &hashAndHdr{header, hash} } +// RequestMissingFinalityAttestingHeaders - func (sp *shardProcessor) RequestMissingFinalityAttestingHeaders() uint32 { sp.hdrsForCurrBlock.mutHdrsForBlock.Lock() defer sp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() @@ -336,10 +387,12 @@ func (sp *shardProcessor) RequestMissingFinalityAttestingHeaders() uint32 { ) } +// CheckMetaHeadersValidityAndFinality - func (sp *shardProcessor) CheckMetaHeadersValidityAndFinality() error { return sp.checkMetaHeadersValidityAndFinality() } +// CreateAndProcessMiniBlocksDstMe - func (sp *shardProcessor) CreateAndProcessMiniBlocksDstMe( haveTime func() bool, ) (block.MiniBlockSlice, uint32, uint32, error) { @@ -347,6 +400,7 @@ func (sp *shardProcessor) CreateAndProcessMiniBlocksDstMe( return createAndProcessInfo.miniBlocks, createAndProcessInfo.numHdrsAdded, createAndProcessInfo.numTxsAdded, err } +// DisplayLogInfo - func (sp *shardProcessor) DisplayLogInfo( header data.HeaderHandler, body *block.Body, @@ -359,10 +413,12 @@ func (sp *shardProcessor) DisplayLogInfo( sp.txCounter.displayLogInfo(header, body, headerHash, numShards, selfId, dataPool, blockTracker) } +// GetHighestHdrForOwnShardFromMetachain - func (sp *shardProcessor) GetHighestHdrForOwnShardFromMetachain(processedHdrs []data.HeaderHandler) ([]data.HeaderHandler, [][]byte, error) { return sp.getHighestHdrForOwnShardFromMetachain(processedHdrs) } +// RestoreMetaBlockIntoPool - func (sp *shardProcessor) RestoreMetaBlockIntoPool( miniBlockHashes map[string]uint32, metaBlockHashes [][]byte, @@ -371,60 +427,94 @@ func (sp *shardProcessor) RestoreMetaBlockIntoPool( return sp.restoreMetaBlockIntoPool(headerHandler, miniBlockHashes, metaBlockHashes) } +// GetAllMiniBlockDstMeFromMeta - func (sp *shardProcessor) GetAllMiniBlockDstMeFromMeta( header data.ShardHeaderHandler, ) (map[string][]byte, error) { return sp.getAllMiniBlockDstMeFromMeta(header) } +// SetHdrForCurrentBlock - func (bp *baseProcessor) SetHdrForCurrentBlock(headerHash 
[]byte, headerHandler data.HeaderHandler, usedInBlock bool) { bp.hdrsForCurrBlock.mutHdrsForBlock.Lock() bp.hdrsForCurrBlock.hdrHashAndInfo[string(headerHash)] = &hdrInfo{hdr: headerHandler, usedInBlock: usedInBlock} bp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() } +// SetHighestHdrNonceForCurrentBlock - func (bp *baseProcessor) SetHighestHdrNonceForCurrentBlock(shardId uint32, value uint64) { bp.hdrsForCurrBlock.mutHdrsForBlock.Lock() bp.hdrsForCurrBlock.highestHdrNonce[shardId] = value bp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() } +// LastNotarizedHeaderInfo - +type LastNotarizedHeaderInfo struct { + Header data.HeaderHandler + Hash []byte + NotarizedBasedOnProof bool + HasProof bool +} + +// SetLastNotarizedHeaderForShard - +func (bp *baseProcessor) SetLastNotarizedHeaderForShard(shardId uint32, info *LastNotarizedHeaderInfo) { + bp.hdrsForCurrBlock.mutHdrsForBlock.Lock() + lastNotarizedShardInfo := &lastNotarizedHeaderInfo{ + header: info.Header, + hash: info.Hash, + notarizedBasedOnProof: info.NotarizedBasedOnProof, + hasProof: info.HasProof, + } + bp.hdrsForCurrBlock.lastNotarizedShardHeaders[shardId] = lastNotarizedShardInfo + bp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() +} + +// CreateBlockStarted - func (bp *baseProcessor) CreateBlockStarted() error { return bp.createBlockStarted() } +// AddProcessedCrossMiniBlocksFromHeader - func (sp *shardProcessor) AddProcessedCrossMiniBlocksFromHeader(header data.HeaderHandler) error { return sp.addProcessedCrossMiniBlocksFromHeader(header) } +// VerifyCrossShardMiniBlockDstMe - func (mp *metaProcessor) VerifyCrossShardMiniBlockDstMe(header *block.MetaBlock) error { return mp.verifyCrossShardMiniBlockDstMe(header) } +// ApplyBodyToHeader - func (mp *metaProcessor) ApplyBodyToHeader(metaHdr data.MetaHeaderHandler, body *block.Body) (data.BodyHandler, error) { return mp.applyBodyToHeader(metaHdr, body) } +// ApplyBodyToHeader - func (sp *shardProcessor) ApplyBodyToHeader(shardHdr data.ShardHeaderHandler, body *block.Body, processedMiniBlocksDestMeInfo map[string]*processedMb.ProcessedMiniBlockInfo) (*block.Body, error) { return sp.applyBodyToHeader(shardHdr, body, processedMiniBlocksDestMeInfo) } +// CreateBlockBody - func (mp *metaProcessor) CreateBlockBody(metaBlock data.HeaderHandler, haveTime func() bool) (data.BodyHandler, error) { return mp.createBlockBody(metaBlock, haveTime) } +// CreateBlockBody - func (sp *shardProcessor) CreateBlockBody(shardHdr data.HeaderHandler, haveTime func() bool) (data.BodyHandler, map[string]*processedMb.ProcessedMiniBlockInfo, error) { return sp.createBlockBody(shardHdr, haveTime) } +// CheckEpochCorrectnessCrossChain - func (sp *shardProcessor) CheckEpochCorrectnessCrossChain() error { return sp.checkEpochCorrectnessCrossChain() } +// CheckEpochCorrectness - func (sp *shardProcessor) CheckEpochCorrectness(header *block.Header) error { return sp.checkEpochCorrectness(header) } +// GetBootstrapHeadersInfo - func (sp *shardProcessor) GetBootstrapHeadersInfo( selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte, @@ -432,18 +522,22 @@ func (sp *shardProcessor) GetBootstrapHeadersInfo( return sp.getBootstrapHeadersInfo(selfNotarizedHeaders, selfNotarizedHeadersHashes) } +// RequestMetaHeadersIfNeeded - func (sp *shardProcessor) RequestMetaHeadersIfNeeded(hdrsAdded uint32, lastMetaHdr data.HeaderHandler) { sp.requestMetaHeadersIfNeeded(hdrsAdded, lastMetaHdr) } +// RequestShardHeadersIfNeeded - func (mp *metaProcessor) RequestShardHeadersIfNeeded(hdrsAddedForShard map[uint32]uint32, 
lastShardHdr map[uint32]data.HeaderHandler) { mp.requestShardHeadersIfNeeded(hdrsAddedForShard, lastShardHdr) } +// AddHeaderIntoTrackerPool - func (bp *baseProcessor) AddHeaderIntoTrackerPool(nonce uint64, shardID uint32) { bp.addHeaderIntoTrackerPool(nonce, shardID) } +// UpdateState - func (bp *baseProcessor) UpdateState( finalHeader data.HeaderHandler, rootHash []byte, @@ -453,14 +547,17 @@ func (bp *baseProcessor) UpdateState( bp.updateStateStorage(finalHeader, rootHash, prevRootHash, accounts) } +// GasAndFeesDelta - func GasAndFeesDelta(initialGasAndFees, finalGasAndFees scheduled.GasAndFees) scheduled.GasAndFees { return gasAndFeesDelta(initialGasAndFees, finalGasAndFees) } +// RequestEpochStartInfo - func (sp *shardProcessor) RequestEpochStartInfo(header data.ShardHeaderHandler, haveTime func() time.Duration) error { return sp.requestEpochStartInfo(header, haveTime) } +// ProcessEpochStartMetaBlock - func (mp *metaProcessor) ProcessEpochStartMetaBlock( header *block.MetaBlock, body *block.Body, @@ -468,30 +565,37 @@ func (mp *metaProcessor) ProcessEpochStartMetaBlock( return mp.processEpochStartMetaBlock(header, body) } +// UpdateEpochStartHeader - func (mp *metaProcessor) UpdateEpochStartHeader(metaHdr *block.MetaBlock) error { return mp.updateEpochStartHeader(metaHdr) } +// CreateEpochStartBody - func (mp *metaProcessor) CreateEpochStartBody(metaBlock *block.MetaBlock) (data.BodyHandler, error) { return mp.createEpochStartBody(metaBlock) } +// GetIndexOfFirstMiniBlockToBeExecuted - func (bp *baseProcessor) GetIndexOfFirstMiniBlockToBeExecuted(header data.HeaderHandler) int { return bp.getIndexOfFirstMiniBlockToBeExecuted(header) } +// GetFinalMiniBlocks - func (bp *baseProcessor) GetFinalMiniBlocks(header data.HeaderHandler, body *block.Body) (*block.Body, error) { return bp.getFinalMiniBlocks(header, body) } +// GetScheduledMiniBlocksFromMe - func GetScheduledMiniBlocksFromMe(headerHandler data.HeaderHandler, bodyHandler data.BodyHandler) (block.MiniBlockSlice, error) { return getScheduledMiniBlocksFromMe(headerHandler, bodyHandler) } +// CheckScheduledMiniBlocksValidity - func (bp *baseProcessor) CheckScheduledMiniBlocksValidity(headerHandler data.HeaderHandler) error { return bp.checkScheduledMiniBlocksValidity(headerHandler) } +// SetMiniBlockHeaderReservedField - func (bp *baseProcessor) SetMiniBlockHeaderReservedField( miniBlock *block.MiniBlock, miniBlockHeaderHandler data.MiniBlockHeaderHandler, @@ -500,18 +604,22 @@ func (bp *baseProcessor) SetMiniBlockHeaderReservedField( return bp.setMiniBlockHeaderReservedField(miniBlock, miniBlockHeaderHandler, processedMiniBlocksDestMeInfo) } +// GetFinalMiniBlockHeaders - func (mp *metaProcessor) GetFinalMiniBlockHeaders(miniBlockHeaderHandlers []data.MiniBlockHeaderHandler) []data.MiniBlockHeaderHandler { return mp.getFinalMiniBlockHeaders(miniBlockHeaderHandlers) } +// CheckProcessorNilParameters - func CheckProcessorNilParameters(arguments ArgBaseProcessor) error { return checkProcessorParameters(arguments) } +// SetIndexOfFirstTxProcessed - func (bp *baseProcessor) SetIndexOfFirstTxProcessed(miniBlockHeaderHandler data.MiniBlockHeaderHandler) error { return bp.setIndexOfFirstTxProcessed(miniBlockHeaderHandler) } +// SetIndexOfLastTxProcessed - func (bp *baseProcessor) SetIndexOfLastTxProcessed( miniBlockHeaderHandler data.MiniBlockHeaderHandler, processedMiniBlocksDestMeInfo map[string]*processedMb.ProcessedMiniBlockInfo, @@ -519,10 +627,12 @@ func (bp *baseProcessor) SetIndexOfLastTxProcessed( return 
bp.setIndexOfLastTxProcessed(miniBlockHeaderHandler, processedMiniBlocksDestMeInfo) } +// GetProcessedMiniBlocksTracker - func (bp *baseProcessor) GetProcessedMiniBlocksTracker() process.ProcessedMiniBlocksTracker { return bp.processedMiniBlocksTracker } +// SetProcessingTypeAndConstructionStateForScheduledMb - func (bp *baseProcessor) SetProcessingTypeAndConstructionStateForScheduledMb( miniBlockHeaderHandler data.MiniBlockHeaderHandler, processedMiniBlocksDestMeInfo map[string]*processedMb.ProcessedMiniBlockInfo, @@ -530,6 +640,7 @@ func (bp *baseProcessor) SetProcessingTypeAndConstructionStateForScheduledMb( return bp.setProcessingTypeAndConstructionStateForScheduledMb(miniBlockHeaderHandler, processedMiniBlocksDestMeInfo) } +// SetProcessingTypeAndConstructionStateForNormalMb - func (bp *baseProcessor) SetProcessingTypeAndConstructionStateForNormalMb( miniBlockHeaderHandler data.MiniBlockHeaderHandler, processedMiniBlocksDestMeInfo map[string]*processedMb.ProcessedMiniBlockInfo, @@ -537,26 +648,32 @@ func (bp *baseProcessor) SetProcessingTypeAndConstructionStateForNormalMb( return bp.setProcessingTypeAndConstructionStateForNormalMb(miniBlockHeaderHandler, processedMiniBlocksDestMeInfo) } +// RollBackProcessedMiniBlockInfo - func (sp *shardProcessor) RollBackProcessedMiniBlockInfo(miniBlockHeader data.MiniBlockHeaderHandler, miniBlockHash []byte) { sp.rollBackProcessedMiniBlockInfo(miniBlockHeader, miniBlockHash) } +// SetProcessedMiniBlocksInfo - func (sp *shardProcessor) SetProcessedMiniBlocksInfo(miniBlockHashes [][]byte, metaBlockHash string, metaBlock *block.MetaBlock) { sp.setProcessedMiniBlocksInfo(miniBlockHashes, metaBlockHash, metaBlock) } +// GetIndexOfLastTxProcessedInMiniBlock - func (sp *shardProcessor) GetIndexOfLastTxProcessedInMiniBlock(miniBlockHash []byte, metaBlock *block.MetaBlock) int32 { return getIndexOfLastTxProcessedInMiniBlock(miniBlockHash, metaBlock) } +// RollBackProcessedMiniBlocksInfo - func (sp *shardProcessor) RollBackProcessedMiniBlocksInfo(headerHandler data.HeaderHandler, mapMiniBlockHashes map[string]uint32) { sp.rollBackProcessedMiniBlocksInfo(headerHandler, mapMiniBlockHashes) } +// CheckConstructionStateAndIndexesCorrectness - func (bp *baseProcessor) CheckConstructionStateAndIndexesCorrectness(mbh data.MiniBlockHeaderHandler) error { return checkConstructionStateAndIndexesCorrectness(mbh) } +// GetAllMarshalledTxs - func (mp *metaProcessor) GetAllMarshalledTxs(body *block.Body) map[string][][]byte { return mp.getAllMarshalledTxs(body) } @@ -582,12 +699,12 @@ func (mp *metaProcessor) ChannelReceiveAllHeaders() chan bool { } // ComputeExistingAndRequestMissingShardHeaders - -func (mp *metaProcessor) ComputeExistingAndRequestMissingShardHeaders(metaBlock *block.MetaBlock) (uint32, uint32) { +func (mp *metaProcessor) ComputeExistingAndRequestMissingShardHeaders(metaBlock *block.MetaBlock) (uint32, uint32, uint32) { return mp.computeExistingAndRequestMissingShardHeaders(metaBlock) } // ComputeExistingAndRequestMissingMetaHeaders - -func (sp *shardProcessor) ComputeExistingAndRequestMissingMetaHeaders(header data.ShardHeaderHandler) (uint32, uint32) { +func (sp *shardProcessor) ComputeExistingAndRequestMissingMetaHeaders(header data.ShardHeaderHandler) (uint32, uint32, uint32) { return sp.computeExistingAndRequestMissingMetaHeaders(header) } @@ -598,7 +715,7 @@ func (sp *shardProcessor) GetHdrForBlock() *hdrForBlock { // ChannelReceiveAllHeaders - func (sp *shardProcessor) ChannelReceiveAllHeaders() chan bool { - return sp.chRcvAllMetaHdrs + return 
sp.chRcvAllHdrs } // InitMaps - @@ -706,3 +823,11 @@ func (hfb *hdrForBlock) GetHdrHashAndInfo() map[string]*HdrInfo { return m } + +// DisplayHeader - +func DisplayHeader( + headerHandler data.HeaderHandler, + headerProof data.HeaderProofHandler, +) []*display.LineData { + return displayHeader(headerHandler, headerProof) +} diff --git a/process/block/hdrForBlock.go b/process/block/hdrForBlock.go index fd7384aedc7..da443cf4aab 100644 --- a/process/block/hdrForBlock.go +++ b/process/block/hdrForBlock.go @@ -6,18 +6,28 @@ import ( "github.com/multiversx/mx-chain-core-go/data" ) +type lastNotarizedHeaderInfo struct { + header data.HeaderHandler + hash []byte + notarizedBasedOnProof bool + hasProof bool +} + type hdrForBlock struct { missingHdrs uint32 missingFinalityAttestingHdrs uint32 + missingProofs uint32 highestHdrNonce map[uint32]uint64 mutHdrsForBlock sync.RWMutex hdrHashAndInfo map[string]*hdrInfo + lastNotarizedShardHeaders map[uint32]*lastNotarizedHeaderInfo } func newHdrForBlock() *hdrForBlock { return &hdrForBlock{ - hdrHashAndInfo: make(map[string]*hdrInfo), - highestHdrNonce: make(map[uint32]uint64), + hdrHashAndInfo: make(map[string]*hdrInfo), + highestHdrNonce: make(map[uint32]uint64), + lastNotarizedShardHeaders: make(map[uint32]*lastNotarizedHeaderInfo), } } @@ -25,6 +35,7 @@ func (hfb *hdrForBlock) initMaps() { hfb.mutHdrsForBlock.Lock() hfb.hdrHashAndInfo = make(map[string]*hdrInfo) hfb.highestHdrNonce = make(map[uint32]uint64) + hfb.lastNotarizedShardHeaders = make(map[uint32]*lastNotarizedHeaderInfo) hfb.mutHdrsForBlock.Unlock() } @@ -32,6 +43,7 @@ func (hfb *hdrForBlock) resetMissingHdrs() { hfb.mutHdrsForBlock.Lock() hfb.missingHdrs = 0 hfb.missingFinalityAttestingHdrs = 0 + hfb.missingProofs = 0 hfb.mutHdrsForBlock.Unlock() } diff --git a/process/block/headerValidator.go b/process/block/headerValidator.go index b39787c7a96..199b793b36b 100644 --- a/process/block/headerValidator.go +++ b/process/block/headerValidator.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/process" ) @@ -15,13 +16,15 @@ var _ process.HeaderConstructionValidator = (*headerValidator)(nil) // ArgsHeaderValidator are the arguments needed to create a new header validator type ArgsHeaderValidator struct { - Hasher hashing.Hasher - Marshalizer marshal.Marshalizer + Hasher hashing.Hasher + Marshalizer marshal.Marshalizer + EnableEpochsHandler core.EnableEpochsHandler } type headerValidator struct { - hasher hashing.Hasher - marshalizer marshal.Marshalizer + hasher hashing.Hasher + marshalizer marshal.Marshalizer + enableEpochsHandler core.EnableEpochsHandler } // NewHeaderValidator returns a new header validator @@ -32,10 +35,14 @@ func NewHeaderValidator(args ArgsHeaderValidator) (*headerValidator, error) { if check.IfNil(args.Marshalizer) { return nil, process.ErrNilMarshalizer } + if check.IfNil(args.EnableEpochsHandler) { + return nil, process.ErrNilEnableEpochsHandler + } return &headerValidator{ - hasher: args.Hasher, - marshalizer: args.Marshalizer, + hasher: args.Hasher, + marshalizer: args.Marshalizer, + enableEpochsHandler: args.EnableEpochsHandler, }, nil } diff --git a/process/block/interceptedBlocks/argInterceptedBlockHeader.go b/process/block/interceptedBlocks/argInterceptedBlockHeader.go index 50d5b2be82f..3e763e64ce4 100644 --- 
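The headerValidator change above extends the usual constructor guard: every dependency, now including EnableEpochsHandler, is nil-checked up front and a sentinel error is returned. A minimal sketch of that guard pattern, with simplified interfaces that are not the real mx-chain-go types:

package sketch

import "errors"

var (
	errNilHasher              = errors.New("nil hasher")
	errNilMarshalizer         = errors.New("nil marshalizer")
	errNilEnableEpochsHandler = errors.New("nil enable epochs handler")
)

type hasher interface{ Compute(string) []byte }
type marshalizer interface {
	Marshal(obj interface{}) ([]byte, error)
}
type enableEpochsHandler interface {
	IsFlagEnabledInEpoch(flag string, epoch uint32) bool
}

type headerValidator struct {
	hasher              hasher
	marshalizer         marshalizer
	enableEpochsHandler enableEpochsHandler
}

// newHeaderValidator rejects any nil dependency before building the instance.
func newHeaderValidator(h hasher, m marshalizer, eeh enableEpochsHandler) (*headerValidator, error) {
	if h == nil {
		return nil, errNilHasher
	}
	if m == nil {
		return nil, errNilMarshalizer
	}
	if eeh == nil {
		return nil, errNilEnableEpochsHandler
	}
	return &headerValidator{hasher: h, marshalizer: m, enableEpochsHandler: eeh}, nil
}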
a/process/block/interceptedBlocks/argInterceptedBlockHeader.go +++ b/process/block/interceptedBlocks/argInterceptedBlockHeader.go @@ -3,18 +3,22 @@ package interceptedBlocks import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" ) // ArgInterceptedBlockHeader is the argument for the intercepted header type ArgInterceptedBlockHeader struct { - HdrBuff []byte - Marshalizer marshal.Marshalizer - Hasher hashing.Hasher - ShardCoordinator sharding.Coordinator - HeaderSigVerifier process.InterceptedHeaderSigVerifier - HeaderIntegrityVerifier process.HeaderIntegrityVerifier - ValidityAttester process.ValidityAttester - EpochStartTrigger process.EpochStartTriggerHandler + HdrBuff []byte + Marshalizer marshal.Marshalizer + Hasher hashing.Hasher + ShardCoordinator sharding.Coordinator + HeaderSigVerifier process.InterceptedHeaderSigVerifier + HeaderIntegrityVerifier process.HeaderIntegrityVerifier + ValidityAttester process.ValidityAttester + EpochStartTrigger process.EpochStartTriggerHandler + EnableEpochsHandler common.EnableEpochsHandler + EpochChangeGracePeriodHandler common.EpochChangeGracePeriodHandler } diff --git a/process/block/interceptedBlocks/common.go b/process/block/interceptedBlocks/common.go index f3d3f1e393f..69a1fcd3383 100644 --- a/process/block/interceptedBlocks/common.go +++ b/process/block/interceptedBlocks/common.go @@ -4,6 +4,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" ) @@ -39,6 +41,12 @@ func checkBlockHeaderArgument(arg *ArgInterceptedBlockHeader) error { if check.IfNil(arg.ValidityAttester) { return process.ErrNilValidityAttester } + if check.IfNil(arg.EnableEpochsHandler) { + return process.ErrNilEnableEpochsHandler + } + if check.IfNil(arg.EpochChangeGracePeriodHandler) { + return process.ErrNilEpochChangeGracePeriodHandler + } return nil } @@ -63,14 +71,19 @@ func checkMiniblockArgument(arg *ArgInterceptedMiniblock) error { return nil } -func checkHeaderHandler(hdr data.HeaderHandler) error { - if len(hdr.GetPubKeysBitmap()) == 0 { +func checkHeaderHandler( + hdr data.HeaderHandler, + enableEpochsHandler common.EnableEpochsHandler, +) error { + equivalentMessagesEnabled := enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, hdr.GetEpoch()) + + if len(hdr.GetPubKeysBitmap()) == 0 && !equivalentMessagesEnabled { return process.ErrNilPubKeysBitmap } if len(hdr.GetPrevHash()) == 0 { return process.ErrNilPreviousBlockHash } - if len(hdr.GetSignature()) == 0 { + if len(hdr.GetSignature()) == 0 && !equivalentMessagesEnabled { return process.ErrNilSignature } if len(hdr.GetRootHash()) == 0 { @@ -86,7 +99,14 @@ func checkHeaderHandler(hdr data.HeaderHandler) error { return hdr.CheckFieldsForNil() } -func checkMetaShardInfo(shardInfo []data.ShardDataHandler, coordinator sharding.Coordinator) error { +func checkMetaShardInfo( + shardInfo []data.ShardDataHandler, + coordinator sharding.Coordinator, +) error { + if coordinator.SelfId() != 
core.MetachainShardId { + return nil + } + for _, sd := range shardInfo { if sd.GetShardID() >= coordinator.NumberOfShards() && sd.GetShardID() != core.MetachainShardId { return process.ErrInvalidShardId diff --git a/process/block/interceptedBlocks/common_test.go b/process/block/interceptedBlocks/common_test.go index 02be37e9bde..5eb6bf9c4bf 100644 --- a/process/block/interceptedBlocks/common_test.go +++ b/process/block/interceptedBlocks/common_test.go @@ -1,28 +1,36 @@ package interceptedBlocks import ( - "errors" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/common/graceperiod" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/stretchr/testify/assert" ) func createDefaultBlockHeaderArgument() *ArgInterceptedBlockHeader { + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) arg := &ArgInterceptedBlockHeader{ - ShardCoordinator: mock.NewOneShardCoordinatorMock(), - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, - HdrBuff: []byte("test buffer"), - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, - HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, - ValidityAttester: &mock.ValidityAttesterStub{}, - EpochStartTrigger: &mock.EpochStartTriggerStub{}, + ShardCoordinator: mock.NewOneShardCoordinatorMock(), + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + HdrBuff: []byte("test buffer"), + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + ValidityAttester: &mock.ValidityAttesterStub{}, + EpochStartTrigger: &mock.EpochStartTriggerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + EpochChangeGracePeriodHandler: gracePeriod, } return arg @@ -62,7 +70,7 @@ func createDefaultHeaderHandler() *testscommon.HeaderHandlerStub { } } -//-------- checkBlockHeaderArgument +// -------- checkBlockHeaderArgument func TestCheckBlockHeaderArgument_NilArgumentShouldErr(t *testing.T) { t.Parallel() @@ -127,6 +135,28 @@ func TestCheckBlockHeaderArgument_NilShardCoordinatorShouldErr(t *testing.T) { assert.Equal(t, process.ErrNilShardCoordinator, err) } +func TestCheckBlockHeaderArgument_NilHeaderIntegrityVerifierShouldErr(t *testing.T) { + t.Parallel() + + arg := createDefaultBlockHeaderArgument() + arg.HeaderIntegrityVerifier = nil + + err := checkBlockHeaderArgument(arg) + + assert.Equal(t, process.ErrNilHeaderIntegrityVerifier, err) +} + +func TestCheckBlockHeaderArgument_NilEpochStartTriggerShouldErr(t *testing.T) { + t.Parallel() + + arg := createDefaultBlockHeaderArgument() + arg.EpochStartTrigger = nil + + err := checkBlockHeaderArgument(arg) + + assert.Equal(t, process.ErrNilEpochStartTrigger, err) +} + func 
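The checkHeaderHandler change above gates two of the classic checks on the equivalent-proofs (Andromeda) flag: once the flag is active for the header's epoch, an empty public-keys bitmap or signature is no longer an error, since the aggregated proof travels separately; checkMetaShardInfo is likewise skipped entirely outside the metachain. A small sketch of the epoch-gated validation, using simplified types rather than the real header interfaces:

package sketch

import "errors"

var (
	errNilPubKeysBitmap = errors.New("nil public keys bitmap")
	errNilSignature     = errors.New("nil signature")
)

type header struct {
	epoch         uint32
	pubKeysBitmap []byte
	signature     []byte
}

// checkHeaderSketch only enforces the bitmap/signature checks while the
// equivalent-proofs flag is not yet active in the header's epoch.
func checkHeaderSketch(hdr header, proofsEnabledInEpoch func(epoch uint32) bool) error {
	proofsEnabled := proofsEnabledInEpoch(hdr.epoch)
	if len(hdr.pubKeysBitmap) == 0 && !proofsEnabled {
		return errNilPubKeysBitmap
	}
	if len(hdr.signature) == 0 && !proofsEnabled {
		return errNilSignature
	}
	return nil
}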
TestCheckBlockHeaderArgument_NilValidityAttesterShouldErr(t *testing.T) { t.Parallel() @@ -138,6 +168,17 @@ func TestCheckBlockHeaderArgument_NilValidityAttesterShouldErr(t *testing.T) { assert.Equal(t, process.ErrNilValidityAttester, err) } +func TestCheckBlockHeaderArgument_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + arg := createDefaultBlockHeaderArgument() + arg.EnableEpochsHandler = nil + + err := checkBlockHeaderArgument(arg) + + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) +} + func TestCheckBlockHeaderArgument_ShouldWork(t *testing.T) { t.Parallel() @@ -148,7 +189,7 @@ func TestCheckBlockHeaderArgument_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//-------- checkMiniblockArgument +// -------- checkMiniblockArgument func TestCheckMiniblockArgument_NilArgumentShouldErr(t *testing.T) { t.Parallel() @@ -212,7 +253,7 @@ func TestCheckMiniblockArgument_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//-------- checkHeaderHandler +// -------- checkHeaderHandler func TestCheckHeaderHandler_NilPubKeysBitmapShouldErr(t *testing.T) { t.Parallel() @@ -222,7 +263,7 @@ func TestCheckHeaderHandler_NilPubKeysBitmapShouldErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilPubKeysBitmap, err) } @@ -235,7 +276,7 @@ func TestCheckHeaderHandler_NilPrevHashShouldErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilPreviousBlockHash, err) } @@ -248,7 +289,7 @@ func TestCheckHeaderHandler_NilSignatureShouldErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilSignature, err) } @@ -261,7 +302,7 @@ func TestCheckHeaderHandler_NilRootHashErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilRootHash, err) } @@ -274,7 +315,7 @@ func TestCheckHeaderHandler_NilRandSeedErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilRandSeed, err) } @@ -287,7 +328,7 @@ func TestCheckHeaderHandler_NilPrevRandSeedErr(t *testing.T) { return nil } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, process.ErrNilPrevRandSeed, err) } @@ -295,13 +336,12 @@ func TestCheckHeaderHandler_NilPrevRandSeedErr(t *testing.T) { func TestCheckHeaderHandler_CheckFieldsForNilErrors(t *testing.T) { t.Parallel() - expectedErr := errors.New("expected error") hdr := createDefaultHeaderHandler() hdr.CheckFieldsForNilCalled = func() error { return expectedErr } - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Equal(t, expectedErr, err) } @@ -311,12 +351,12 @@ func TestCheckHeaderHandler_ShouldWork(t *testing.T) { hdr := createDefaultHeaderHandler() - err := checkHeaderHandler(hdr) + err := checkHeaderHandler(hdr, enableEpochsHandlerMock.NewEnableEpochsHandlerStub()) assert.Nil(t, err) } -//------- checkMetaShardInfo +// ------- checkMetaShardInfo func TestCheckMetaShardInfo_WithNilOrEmptyShouldReturnNil(t 
*testing.T) { t.Parallel() @@ -330,10 +370,23 @@ func TestCheckMetaShardInfo_WithNilOrEmptyShouldReturnNil(t *testing.T) { assert.Nil(t, err2) } +func TestCheckMetaShardInfo_ShouldNotCheckShardInfoForShards(t *testing.T) { + t.Parallel() + + shardCoordinator := mock.NewOneShardCoordinatorMock() + _ = shardCoordinator.SetSelfId(1) + + sd := block.ShardData{} + + err := checkMetaShardInfo([]data.ShardDataHandler{&sd}, shardCoordinator) + assert.Nil(t, err) +} + func TestCheckMetaShardInfo_WrongShardIdShouldErr(t *testing.T) { t.Parallel() shardCoordinator := mock.NewOneShardCoordinatorMock() + _ = shardCoordinator.SetSelfId(core.MetachainShardId) wrongShardId := uint32(2) sd := block.ShardData{ ShardID: wrongShardId, @@ -351,6 +404,7 @@ func TestCheckMetaShardInfo_WrongMiniblockSenderShardIdShouldErr(t *testing.T) { t.Parallel() shardCoordinator := mock.NewOneShardCoordinatorMock() + _ = shardCoordinator.SetSelfId(core.MetachainShardId) wrongShardId := uint32(2) miniBlock := block.MiniBlockHeader{ Hash: make([]byte, 0), @@ -375,6 +429,7 @@ func TestCheckMetaShardInfo_WrongMiniblockReceiverShardIdShouldErr(t *testing.T) t.Parallel() shardCoordinator := mock.NewOneShardCoordinatorMock() + _ = shardCoordinator.SetSelfId(core.MetachainShardId) wrongShardId := uint32(2) miniBlock := block.MiniBlockHeader{ Hash: make([]byte, 0), @@ -399,6 +454,8 @@ func TestCheckMetaShardInfo_ReservedPopulatedShouldErr(t *testing.T) { t.Parallel() shardCoordinator := mock.NewOneShardCoordinatorMock() + _ = shardCoordinator.SetSelfId(core.MetachainShardId) + miniBlock := block.MiniBlockHeader{ Hash: make([]byte, 0), ReceiverShardID: shardCoordinator.SelfId(), @@ -423,6 +480,7 @@ func TestCheckMetaShardInfo_OkValsShouldWork(t *testing.T) { t.Parallel() shardCoordinator := mock.NewOneShardCoordinatorMock() + _ = shardCoordinator.SetSelfId(core.MetachainShardId) miniBlock := block.MiniBlockHeader{ Hash: make([]byte, 0), ReceiverShardID: shardCoordinator.SelfId(), @@ -446,7 +504,57 @@ func TestCheckMetaShardInfo_OkValsShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- checkMiniBlocksHeaders +func TestCheckMetaShardInfo_WithMultipleShardData(t *testing.T) { + t.Parallel() + + t.Run("should return invalid shard id error, with multiple shard data", func(t *testing.T) { + t.Parallel() + + shardCoordinator := mock.NewOneShardCoordinatorMock() + _ = shardCoordinator.SetSelfId(core.MetachainShardId) + wrongShardId := uint32(2) + miniBlock1 := block.MiniBlockHeader{ + Hash: make([]byte, 0), + ReceiverShardID: wrongShardId, + SenderShardID: shardCoordinator.SelfId(), + TxCount: 0, + } + + miniBlock2 := block.MiniBlockHeader{ + Hash: make([]byte, 0), + ReceiverShardID: shardCoordinator.SelfId(), + SenderShardID: shardCoordinator.SelfId(), + TxCount: 0, + } + + sd1 := &block.ShardData{ + ShardID: shardCoordinator.SelfId(), + HeaderHash: nil, + ShardMiniBlockHeaders: []block.MiniBlockHeader{ + miniBlock2, + }, + TxCount: 0, + } + + sd2 := &block.ShardData{ + ShardID: shardCoordinator.SelfId(), + HeaderHash: nil, + ShardMiniBlockHeaders: []block.MiniBlockHeader{ + miniBlock1, + }, + TxCount: 0, + } + + err := checkMetaShardInfo( + []data.ShardDataHandler{sd1, sd2}, + shardCoordinator, + ) + + assert.Equal(t, process.ErrInvalidShardId, err) + }) +} + +// ------- checkMiniBlocksHeaders func TestCheckMiniBlocksHeaders_WithNilOrEmptyShouldReturnNil(t *testing.T) { t.Parallel() diff --git a/process/block/interceptedBlocks/errors.go b/process/block/interceptedBlocks/errors.go new file mode 100644 index 00000000000..afd3f50cf03 
--- /dev/null +++ b/process/block/interceptedBlocks/errors.go @@ -0,0 +1,8 @@ +package interceptedBlocks + +import "errors" + +var ( + // ErrInvalidProof signals that an invalid proof has been provided + ErrInvalidProof = errors.New("invalid proof") +) diff --git a/process/block/interceptedBlocks/interceptedBlockHeader.go b/process/block/interceptedBlocks/interceptedBlockHeader.go index 81d78bef5c0..181a23f5ac0 100644 --- a/process/block/interceptedBlocks/interceptedBlockHeader.go +++ b/process/block/interceptedBlocks/interceptedBlockHeader.go @@ -6,9 +6,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" + logger "github.com/multiversx/mx-chain-logger-go" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" - logger "github.com/multiversx/mx-chain-logger-go" ) var _ process.HdrValidatorHandler = (*InterceptedHeader)(nil) @@ -17,15 +19,17 @@ var _ process.InterceptedData = (*InterceptedHeader)(nil) // InterceptedHeader represents the wrapper over HeaderWrapper struct. // It implements Newer and Hashed interfaces type InterceptedHeader struct { - hdr data.HeaderHandler - sigVerifier process.InterceptedHeaderSigVerifier - integrityVerifier process.HeaderIntegrityVerifier - hasher hashing.Hasher - shardCoordinator sharding.Coordinator - hash []byte - isForCurrentShard bool - validityAttester process.ValidityAttester - epochStartTrigger process.EpochStartTriggerHandler + hdr data.HeaderHandler + sigVerifier process.InterceptedHeaderSigVerifier + integrityVerifier process.HeaderIntegrityVerifier + hasher hashing.Hasher + shardCoordinator sharding.Coordinator + hash []byte + isForCurrentShard bool + validityAttester process.ValidityAttester + epochStartTrigger process.EpochStartTriggerHandler + enableEpochsHandler common.EnableEpochsHandler + epochChangeGracePeriodHandler common.EpochChangeGracePeriodHandler } // NewInterceptedHeader creates a new instance of InterceptedHeader struct @@ -41,13 +45,15 @@ func NewInterceptedHeader(arg *ArgInterceptedBlockHeader) (*InterceptedHeader, e } inHdr := &InterceptedHeader{ - hdr: hdr, - hasher: arg.Hasher, - sigVerifier: arg.HeaderSigVerifier, - integrityVerifier: arg.HeaderIntegrityVerifier, - shardCoordinator: arg.ShardCoordinator, - validityAttester: arg.ValidityAttester, - epochStartTrigger: arg.EpochStartTrigger, + hdr: hdr, + hasher: arg.Hasher, + sigVerifier: arg.HeaderSigVerifier, + integrityVerifier: arg.HeaderIntegrityVerifier, + shardCoordinator: arg.ShardCoordinator, + validityAttester: arg.ValidityAttester, + epochStartTrigger: arg.EpochStartTrigger, + enableEpochsHandler: arg.EnableEpochsHandler, + epochChangeGracePeriodHandler: arg.EpochChangeGracePeriodHandler, } inHdr.processFields(arg.HdrBuff) @@ -74,7 +80,11 @@ func (inHdr *InterceptedHeader) CheckValidity() error { return err } - err = inHdr.sigVerifier.VerifyRandSeedAndLeaderSignature(inHdr.hdr) + return inHdr.verifySignatures() +} + +func (inHdr *InterceptedHeader) verifySignatures() error { + err := inHdr.sigVerifier.VerifyRandSeedAndLeaderSignature(inHdr.hdr) if err != nil { return err } @@ -95,7 +105,12 @@ func (inHdr *InterceptedHeader) isEpochCorrect() bool { if inHdr.hdr.GetRound() <= inHdr.epochStartTrigger.EpochStartRound() { return true } - if 
inHdr.hdr.GetRound() <= inHdr.epochStartTrigger.EpochFinalityAttestingRound()+process.EpochChangeGracePeriod { + gracePeriod, err := inHdr.epochChangeGracePeriodHandler.GetGracePeriodForEpoch(inHdr.hdr.GetEpoch()) + if err != nil { + log.Warn("isEpochCorrect", "epoch", inHdr.hdr.GetEpoch(), "error", err) + return false + } + if inHdr.hdr.GetRound() <= inHdr.epochStartTrigger.EpochFinalityAttestingRound()+uint64(gracePeriod) { return true } @@ -121,7 +136,7 @@ func (inHdr *InterceptedHeader) integrity() error { inHdr.epochStartTrigger.EpochFinalityAttestingRound()) } - err := checkHeaderHandler(inHdr.HeaderHandler()) + err := checkHeaderHandler(inHdr.HeaderHandler(), inHdr.enableEpochsHandler) if err != nil { return err } diff --git a/process/block/interceptedBlocks/interceptedBlockHeader_test.go b/process/block/interceptedBlocks/interceptedBlockHeader_test.go index a107e01dc3e..e3e03707f49 100644 --- a/process/block/interceptedBlocks/interceptedBlockHeader_test.go +++ b/process/block/interceptedBlocks/interceptedBlockHeader_test.go @@ -10,12 +10,18 @@ import ( "github.com/multiversx/mx-chain-core-go/data" dataBlock "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/graceperiod" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/process/mock" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var testMarshalizer = &mock.MarshalizerMock{} @@ -26,14 +32,17 @@ var hdrRound = uint64(67) var hdrEpoch = uint32(78) func createDefaultShardArgument() *interceptedBlocks.ArgInterceptedBlockHeader { + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) arg := &interceptedBlocks.ArgInterceptedBlockHeader{ - ShardCoordinator: mock.NewOneShardCoordinatorMock(), - Hasher: testHasher, - Marshalizer: testMarshalizer, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, - HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, - ValidityAttester: &mock.ValidityAttesterStub{}, - EpochStartTrigger: &mock.EpochStartTriggerStub{}, + ShardCoordinator: mock.NewOneShardCoordinatorMock(), + Hasher: testHasher, + Marshalizer: testMarshalizer, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + ValidityAttester: &mock.ValidityAttesterStub{}, + EpochStartTrigger: &mock.EpochStartTriggerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + EpochChangeGracePeriodHandler: gracePeriod, } hdr := createMockShardHeader() @@ -43,14 +52,17 @@ func createDefaultShardArgument() *interceptedBlocks.ArgInterceptedBlockHeader { } func createDefaultShardArgumentWithV2Support() 
*interceptedBlocks.ArgInterceptedBlockHeader { + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) arg := &interceptedBlocks.ArgInterceptedBlockHeader{ - ShardCoordinator: mock.NewOneShardCoordinatorMock(), - Hasher: testHasher, - Marshalizer: &marshal.GogoProtoMarshalizer{}, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, - HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, - ValidityAttester: &mock.ValidityAttesterStub{}, - EpochStartTrigger: &mock.EpochStartTriggerStub{}, + ShardCoordinator: mock.NewOneShardCoordinatorMock(), + Hasher: testHasher, + Marshalizer: &marshal.GogoProtoMarshalizer{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + ValidityAttester: &mock.ValidityAttesterStub{}, + EpochStartTrigger: &mock.EpochStartTriggerStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + EpochChangeGracePeriodHandler: gracePeriod, } hdr := createMockShardHeader() arg.HdrBuff, _ = arg.Marshalizer.Marshal(hdr) @@ -83,7 +95,7 @@ func createMockShardHeader() *dataBlock.Header { } } -//------- TestNewInterceptedHeader +// ------- TestNewInterceptedHeader func TestNewInterceptedHeader_NilArgumentShouldErr(t *testing.T) { t.Parallel() @@ -167,7 +179,7 @@ func TestNewInterceptedHeader_MetachainForThisShardShouldWork(t *testing.T) { assert.True(t, inHdr.IsForCurrentShard()) } -//------- CheckValidity +// ------- Verify func TestInterceptedHeader_CheckValidityNilPubKeyBitmapShouldErr(t *testing.T) { t.Parallel() @@ -194,7 +206,7 @@ func TestInterceptedHeader_CheckValidityLeaderSignatureNotCorrectShouldErr(t *te expectedErr := errors.New("expected err") buff, _ := marshaller.Marshal(hdr) - arg.HeaderSigVerifier = &mock.HeaderSigVerifierStub{ + arg.HeaderSigVerifier = &consensus.HeaderSigVerifierMock{ VerifyRandSeedAndLeaderSignatureCalled: func(header data.HeaderHandler) error { return expectedErr }, @@ -226,6 +238,42 @@ func TestInterceptedHeader_CheckValidityLeaderSignatureOkShouldWork(t *testing.T assert.Nil(t, err) } +func TestInterceptedHeader_CheckValidityLeaderSignatureOkWithFlagActiveShouldWork(t *testing.T) { + t.Parallel() + + arg := createDefaultShardArgumentWithV2Support() + arg.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + wasVerifySignatureCalled := false + arg.HeaderSigVerifier = &consensus.HeaderSigVerifierMock{ + VerifySignatureCalled: func(header data.HeaderHandler) error { + wasVerifySignatureCalled = true + + return nil + }, + } + marshaller := arg.Marshalizer + hdr := &dataBlock.HeaderV2{ + Header: createMockShardHeader(), + ScheduledRootHash: []byte("root hash"), + ScheduledAccumulatedFees: big.NewInt(0), + ScheduledDeveloperFees: big.NewInt(0), + } + buff, _ := marshaller.Marshal(hdr) + + arg.HdrBuff = buff + inHdr, err := interceptedBlocks.NewInterceptedHeader(arg) + require.Nil(t, err) + require.NotNil(t, inHdr) + + err = inHdr.CheckValidity() + assert.Nil(t, err) + assert.True(t, wasVerifySignatureCalled) +} + func TestInterceptedHeader_ErrorInMiniBlockShouldErr(t *testing.T) { t.Parallel() @@ -305,7 +353,7 @@ func TestInterceptedHeader_CheckAgainstFinalHeaderErrorsShouldErr(t *testing.T) assert.Equal(t, expectedErr, err) } -//------- getters +// ------- getters func TestInterceptedHeader_Getters(t *testing.T) { 
t.Parallel() @@ -318,7 +366,7 @@ func TestInterceptedHeader_Getters(t *testing.T) { assert.Equal(t, hash, inHdr.Hash()) } -//------- IsInterfaceNil +// ------- IsInterfaceNil func TestInterceptedHeader_IsInterfaceNil(t *testing.T) { t.Parallel() diff --git a/process/block/interceptedBlocks/interceptedEquivalentProof.go b/process/block/interceptedBlocks/interceptedEquivalentProof.go new file mode 100644 index 00000000000..2b6792b8f20 --- /dev/null +++ b/process/block/interceptedBlocks/interceptedEquivalentProof.go @@ -0,0 +1,226 @@ +package interceptedBlocks + +import ( + "fmt" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/core/sync" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-vm-v1_2-go/ipc/marshaling" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/errors" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/sharding" +) + +const interceptedEquivalentProofType = "intercepted equivalent proof" + +// ArgInterceptedEquivalentProof is the argument used in the intercepted equivalent proof constructor +type ArgInterceptedEquivalentProof struct { + DataBuff []byte + Marshaller marshal.Marshalizer + Hasher hashing.Hasher + ShardCoordinator sharding.Coordinator + HeaderSigVerifier consensus.HeaderSigVerifier + Proofs dataRetriever.ProofsPool + ProofSizeChecker common.FieldsSizeChecker + KeyRWMutexHandler sync.KeyRWMutexHandler +} + +type interceptedEquivalentProof struct { + proof *block.HeaderProof + isForCurrentShard bool + headerSigVerifier consensus.HeaderSigVerifier + proofsPool dataRetriever.ProofsPool + marshaller marshaling.Marshalizer + hasher hashing.Hasher + hash []byte + proofSizeChecker common.FieldsSizeChecker + km sync.KeyRWMutexHandler +} + +// NewInterceptedEquivalentProof returns a new instance of interceptedEquivalentProof +func NewInterceptedEquivalentProof(args ArgInterceptedEquivalentProof) (*interceptedEquivalentProof, error) { + err := checkArgInterceptedEquivalentProof(args) + if err != nil { + return nil, err + } + + equivalentProof, err := createEquivalentProof(args.Marshaller, args.DataBuff) + if err != nil { + return nil, err + } + + hash := args.Hasher.Compute(string(args.DataBuff)) + + return &interceptedEquivalentProof{ + proof: equivalentProof, + isForCurrentShard: extractIsForCurrentShard(args.ShardCoordinator, equivalentProof), + headerSigVerifier: args.HeaderSigVerifier, + proofsPool: args.Proofs, + marshaller: args.Marshaller, + hasher: args.Hasher, + proofSizeChecker: args.ProofSizeChecker, + hash: hash, + km: args.KeyRWMutexHandler, + }, nil +} + +func checkArgInterceptedEquivalentProof(args ArgInterceptedEquivalentProof) error { + if len(args.DataBuff) == 0 { + return process.ErrNilBuffer + } + if check.IfNil(args.Marshaller) { + return process.ErrNilMarshalizer + } + if check.IfNil(args.ShardCoordinator) 
{ + return process.ErrNilShardCoordinator + } + if check.IfNil(args.HeaderSigVerifier) { + return process.ErrNilHeaderSigVerifier + } + if check.IfNil(args.Proofs) { + return process.ErrNilProofsPool + } + if check.IfNil(args.Hasher) { + return process.ErrNilHasher + } + if check.IfNil(args.ProofSizeChecker) { + return errors.ErrNilFieldsSizeChecker + } + if check.IfNil(args.KeyRWMutexHandler) { + return process.ErrNilKeyRWMutexHandler + } + + return nil +} + +func createEquivalentProof(marshaller marshal.Marshalizer, buff []byte) (*block.HeaderProof, error) { + headerProof := &block.HeaderProof{} + err := marshaller.Unmarshal(headerProof, buff) + if err != nil { + return nil, err + } + + log.Trace("interceptedEquivalentProof successfully created", + "header hash", logger.DisplayByteSlice(headerProof.HeaderHash), + "header shard", headerProof.HeaderShardId, + "header epoch", headerProof.HeaderEpoch, + "header nonce", headerProof.HeaderNonce, + "header round", headerProof.HeaderRound, + "bitmap", logger.DisplayByteSlice(headerProof.PubKeysBitmap), + "signature", logger.DisplayByteSlice(headerProof.AggregatedSignature), + "isEpochStart", headerProof.IsStartOfEpoch, + ) + + return headerProof, nil +} + +func extractIsForCurrentShard(shardCoordinator sharding.Coordinator, equivalentProof *block.HeaderProof) bool { + proofShardId := equivalentProof.GetHeaderShardId() + if shardCoordinator.SelfId() == core.MetachainShardId { + return true + } + + if proofShardId == core.MetachainShardId { + return true + } + + return proofShardId == shardCoordinator.SelfId() +} + +// CheckValidity checks if the received proof is valid +func (iep *interceptedEquivalentProof) CheckValidity() error { + log.Debug("Checking intercepted equivalent proof validity", "proof header hash", iep.proof.HeaderHash) + err := iep.integrity() + if err != nil { + return err + } + + headerHash := string(iep.proof.GetHeaderHash()) + iep.km.Lock(headerHash) + defer iep.km.Unlock(headerHash) + + ok := iep.proofsPool.HasProof(iep.proof.GetHeaderShardId(), iep.proof.GetHeaderHash()) + if ok { + return common.ErrAlreadyExistingEquivalentProof + } + + err = iep.headerSigVerifier.VerifyHeaderProof(iep.proof) + if err != nil { + return err + } + + // also save the proof here in order to complete the flow under mutex lock + wasAdded := iep.proofsPool.AddProof(iep.proof) + if !wasAdded { + // with the current implementation, this should never happen + return common.ErrAlreadyExistingEquivalentProof + } + + return nil +} + +func (iep *interceptedEquivalentProof) integrity() error { + if !iep.proofSizeChecker.IsProofSizeValid(iep.proof) { + return ErrInvalidProof + } + + return nil +} + +// GetProof returns the underlying intercepted header proof +func (iep *interceptedEquivalentProof) GetProof() data.HeaderProofHandler { + return iep.proof +} + +// IsForCurrentShard returns true if the equivalent proof should be processed by the current shard +func (iep *interceptedEquivalentProof) IsForCurrentShard() bool { + return iep.isForCurrentShard +} + +// Hash returns the header hash the proof belongs to +func (iep *interceptedEquivalentProof) Hash() []byte { + return iep.hash +} + +// Type returns the type of this intercepted data +func (iep *interceptedEquivalentProof) Type() string { + return interceptedEquivalentProofType +} + +// Identifiers returns the identifiers used in requests +func (iep *interceptedEquivalentProof) Identifiers() [][]byte { + return [][]byte{ + iep.proof.HeaderHash, + // needed for the interceptor, when data is requested by 
nonce + []byte(common.GetEquivalentProofNonceShardKey(iep.proof.HeaderNonce, iep.proof.HeaderShardId)), + } +} + +// String returns the proof's most important fields as string +func (iep *interceptedEquivalentProof) String() string { + return fmt.Sprintf("bitmap=%s, signature=%s, hash=%s, epoch=%d, shard=%d, nonce=%d, round=%d, isEpochStart=%t", + logger.DisplayByteSlice(iep.proof.PubKeysBitmap), + logger.DisplayByteSlice(iep.proof.AggregatedSignature), + logger.DisplayByteSlice(iep.proof.HeaderHash), + iep.proof.HeaderEpoch, + iep.proof.HeaderShardId, + iep.proof.HeaderNonce, + iep.proof.HeaderRound, + iep.proof.IsStartOfEpoch, + ) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (iep *interceptedEquivalentProof) IsInterfaceNil() bool { + return iep == nil +} diff --git a/process/block/interceptedBlocks/interceptedEquivalentProof_test.go b/process/block/interceptedBlocks/interceptedEquivalentProof_test.go new file mode 100644 index 00000000000..bd2d741db18 --- /dev/null +++ b/process/block/interceptedBlocks/interceptedEquivalentProof_test.go @@ -0,0 +1,377 @@ +package interceptedBlocks + +import ( + "errors" + "fmt" + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + coreSync "github.com/multiversx/mx-chain-core-go/core/sync" + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + errErd "github.com/multiversx/mx-chain-go/errors" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" +) + +var ( + expectedErr = errors.New("expected error") + testMarshaller = &marshallerMock.MarshalizerMock{} + providedEpoch = uint32(123) + providedNonce = uint64(345) + providedShard = uint32(0) + providedRound = uint64(123456) +) + +func createMockDataBuffWithHash(headerHash []byte) []byte { + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: headerHash, + HeaderEpoch: providedEpoch, + HeaderNonce: providedNonce, + HeaderShardId: providedShard, + HeaderRound: providedRound, + } + + dataBuff, _ := testMarshaller.Marshal(proof) + return dataBuff +} + +func createMockDataBuff() []byte { + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderEpoch: providedEpoch, + HeaderNonce: providedNonce, + HeaderShardId: providedShard, + HeaderRound: providedRound, + } + + dataBuff, _ := testMarshaller.Marshal(proof) + return dataBuff +} + +func createMockArgInterceptedEquivalentProof() ArgInterceptedEquivalentProof { + return ArgInterceptedEquivalentProof{ + DataBuff: createMockDataBuff(), + Marshaller: testMarshaller, + ShardCoordinator: &mock.ShardCoordinatorMock{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + Proofs: 
&dataRetriever.ProofsPoolMock{}, + Hasher: &hashingMocks.HasherMock{}, + ProofSizeChecker: &testscommon.FieldsSizeCheckerMock{}, + KeyRWMutexHandler: coreSync.NewKeyRWMutex(), + } +} + +func TestInterceptedEquivalentProof_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var iep *interceptedEquivalentProof + require.True(t, iep.IsInterfaceNil()) + + iep, _ = NewInterceptedEquivalentProof(createMockArgInterceptedEquivalentProof()) + require.False(t, iep.IsInterfaceNil()) +} + +func TestNewInterceptedEquivalentProof(t *testing.T) { + t.Parallel() + + t.Run("nil DataBuff should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.DataBuff = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilBuffer, err) + require.Nil(t, iep) + }) + t.Run("nil Marshaller should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Marshaller = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilMarshalizer, err) + require.Nil(t, iep) + }) + t.Run("nil ShardCoordinator should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.ShardCoordinator = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilShardCoordinator, err) + require.Nil(t, iep) + }) + t.Run("nil HeaderSigVerifier should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.HeaderSigVerifier = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilHeaderSigVerifier, err) + require.Nil(t, iep) + }) + t.Run("nil proofs pool should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Proofs = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilProofsPool, err) + require.Nil(t, iep) + }) + t.Run("nil Hasher should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Hasher = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilHasher, err) + require.Nil(t, iep) + }) + t.Run("unmarshal error should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Marshaller = &marshallerMock.MarshalizerStub{ + UnmarshalCalled: func(obj interface{}, buff []byte) error { + return expectedErr + }, + } + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, expectedErr, err) + require.Nil(t, iep) + }) + t.Run("nil ProofSizeChecker should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.ProofSizeChecker = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, errErd.ErrNilFieldsSizeChecker, err) + require.Nil(t, iep) + }) + t.Run("nil KeyRWMutexHandler should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.KeyRWMutexHandler = nil + iep, err := NewInterceptedEquivalentProof(args) + require.Equal(t, process.ErrNilKeyRWMutexHandler, err) + require.Nil(t, iep) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + iep, err := NewInterceptedEquivalentProof(createMockArgInterceptedEquivalentProof()) + require.NoError(t, err) + require.NotNil(t, iep) + }) +} + +func TestInterceptedEquivalentProof_CheckValidity(t *testing.T) { + t.Parallel() + + t.Run("invalid proof 
should error", func(t *testing.T) { + t.Parallel() + + // no header hash + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ = args.Marshaller.Marshal(proof) + args.ProofSizeChecker = &testscommon.FieldsSizeCheckerMock{ + IsProofSizeValidCalled: func(proof data.HeaderProofHandler) bool { + return false + }, + } + + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + err = iep.CheckValidity() + require.Equal(t, ErrInvalidProof, err) + }) + t.Run("already existing proof should error", func(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProof() + args.Proofs = &dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + } + + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + err = iep.CheckValidity() + require.Equal(t, common.ErrAlreadyExistingEquivalentProof, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + iep, err := NewInterceptedEquivalentProof(createMockArgInterceptedEquivalentProof()) + require.NoError(t, err) + + err = iep.CheckValidity() + require.NoError(t, err) + }) + t.Run("concurrent calls should work", func(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r != nil { + require.Fail(t, "should not have panicked") + } + }() + + km := coreSync.NewKeyRWMutex() + + numCalls := 1000 + wg := sync.WaitGroup{} + wg.Add(numCalls) + + for i := 0; i < numCalls; i++ { + go func(idx int) { + hash := fmt.Sprintf("hash_%d", idx%5) // make sure hashes repeat + + args := createMockArgInterceptedEquivalentProof() + args.KeyRWMutexHandler = km + args.DataBuff = createMockDataBuffWithHash([]byte(hash)) + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + _ = iep.CheckValidity() + + wg.Done() + }(i) + } + + wg.Wait() + }) +} + +func TestInterceptedEquivalentProof_IsForCurrentShard(t *testing.T) { + t.Parallel() + + t.Run("meta should return true", func(t *testing.T) { + t.Parallel() + + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderShardId: core.MetachainShardId, + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ = args.Marshaller.Marshal(proof) + args.ShardCoordinator = &mock.ShardCoordinatorMock{ShardID: core.MetachainShardId} + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + require.True(t, iep.IsForCurrentShard()) + }) + t.Run("meta proof on different shard should return true", func(t *testing.T) { + t.Parallel() + + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderShardId: core.MetachainShardId, + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ = args.Marshaller.Marshal(proof) + args.ShardCoordinator = &mock.ShardCoordinatorMock{ShardID: 0} + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + require.True(t, iep.IsForCurrentShard()) + }) + t.Run("self shard id return true", func(t *testing.T) { + t.Parallel() + + selfShardId := uint32(1234) + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderShardId: selfShardId, + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ =
args.Marshaller.Marshal(proof) + args.ShardCoordinator = &mock.ShardCoordinatorMock{ShardID: selfShardId} + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + require.True(t, iep.IsForCurrentShard()) + }) + t.Run("other shard id should return false", func(t *testing.T) { + t.Parallel() + + selfShardId := uint32(1234) + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderShardId: selfShardId, + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ = args.Marshaller.Marshal(proof) + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + require.False(t, iep.IsForCurrentShard()) + }) +} + +func TestInterceptedEquivalentProof_Getters(t *testing.T) { + t.Parallel() + + proof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderEpoch: 123, + HeaderNonce: 345, + HeaderShardId: 0, + HeaderRound: 123456, + IsStartOfEpoch: false, + } + args := createMockArgInterceptedEquivalentProof() + args.DataBuff, _ = args.Marshaller.Marshal(proof) + hash := args.Hasher.Compute(string(args.DataBuff)) + iep, err := NewInterceptedEquivalentProof(args) + require.NoError(t, err) + + require.Equal(t, proof, iep.GetProof()) // pointer testing + require.Equal(t, hash, iep.Hash()) + require.Equal(t, [][]byte{ + proof.HeaderHash, + []byte(common.GetEquivalentProofNonceShardKey(proof.HeaderNonce, proof.HeaderShardId)), + }, iep.Identifiers()) + require.Equal(t, interceptedEquivalentProofType, iep.Type()) + expectedStr := fmt.Sprintf("bitmap=%s, signature=%s, hash=%s, epoch=123, shard=0, nonce=345, round=123456, isEpochStart=false", + logger.DisplayByteSlice(proof.PubKeysBitmap), + logger.DisplayByteSlice(proof.AggregatedSignature), + logger.DisplayByteSlice(proof.HeaderHash)) + require.Equal(t, expectedStr, iep.String()) +} diff --git a/process/block/interceptedBlocks/interceptedMetaBlockHeader.go b/process/block/interceptedBlocks/interceptedMetaBlockHeader.go index 415e2da3967..afe966e13f8 100644 --- a/process/block/interceptedBlocks/interceptedMetaBlockHeader.go +++ b/process/block/interceptedBlocks/interceptedMetaBlockHeader.go @@ -8,9 +8,11 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" - logger "github.com/multiversx/mx-chain-logger-go" ) var _ process.HdrValidatorHandler = (*InterceptedMetaHeader)(nil) @@ -20,14 +22,15 @@ var log = logger.GetOrCreate("process/block/interceptedBlocks") // InterceptedMetaHeader represents the wrapper over the meta block header struct type InterceptedMetaHeader struct { - hdr data.MetaHeaderHandler - sigVerifier process.InterceptedHeaderSigVerifier - integrityVerifier process.HeaderIntegrityVerifier - hasher hashing.Hasher - shardCoordinator sharding.Coordinator - hash []byte - validityAttester process.ValidityAttester - epochStartTrigger process.EpochStartTriggerHandler + hdr data.MetaHeaderHandler + sigVerifier process.InterceptedHeaderSigVerifier + integrityVerifier process.HeaderIntegrityVerifier + hasher hashing.Hasher +
shardCoordinator sharding.Coordinator + hash []byte + validityAttester process.ValidityAttester + epochStartTrigger process.EpochStartTriggerHandler + enableEpochsHandler common.EnableEpochsHandler } // NewInterceptedMetaHeader creates a new instance of InterceptedMetaHeader struct @@ -43,13 +46,14 @@ func NewInterceptedMetaHeader(arg *ArgInterceptedBlockHeader) (*InterceptedMetaH } inHdr := &InterceptedMetaHeader{ - hdr: hdr, - hasher: arg.Hasher, - sigVerifier: arg.HeaderSigVerifier, - integrityVerifier: arg.HeaderIntegrityVerifier, - shardCoordinator: arg.ShardCoordinator, - validityAttester: arg.ValidityAttester, - epochStartTrigger: arg.EpochStartTrigger, + hdr: hdr, + hasher: arg.Hasher, + sigVerifier: arg.HeaderSigVerifier, + integrityVerifier: arg.HeaderIntegrityVerifier, + shardCoordinator: arg.ShardCoordinator, + validityAttester: arg.ValidityAttester, + epochStartTrigger: arg.EpochStartTrigger, + enableEpochsHandler: arg.EnableEpochsHandler, } inHdr.processFields(arg.HdrBuff) @@ -84,6 +88,8 @@ func (imh *InterceptedMetaHeader) HeaderHandler() data.HeaderHandler { // CheckValidity checks if the received meta header is valid (not nil fields, valid sig and so on) func (imh *InterceptedMetaHeader) CheckValidity() error { + log.Trace("CheckValidity for header with", "epoch", imh.hdr.GetEpoch(), "hash", logger.DisplayByteSlice(imh.hash)) + err := imh.integrity() if err != nil { return err @@ -137,7 +143,7 @@ func (imh *InterceptedMetaHeader) isMetaHeaderEpochOutOfRange() bool { // integrity checks the integrity of the meta header block wrapper func (imh *InterceptedMetaHeader) integrity() error { - err := checkHeaderHandler(imh.HeaderHandler()) + err := checkHeaderHandler(imh.HeaderHandler(), imh.enableEpochsHandler) if err != nil { return err } diff --git a/process/block/interceptedBlocks/interceptedMetaBlockHeader_test.go b/process/block/interceptedBlocks/interceptedMetaBlockHeader_test.go index 99fc49d1dd3..16823210f87 100644 --- a/process/block/interceptedBlocks/interceptedMetaBlockHeader_test.go +++ b/process/block/interceptedBlocks/interceptedMetaBlockHeader_test.go @@ -5,22 +5,35 @@ import ( "math/big" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" dataBlock "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common/graceperiod" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/process/mock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/sharding" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" ) func createDefaultMetaArgument() *interceptedBlocks.ArgInterceptedBlockHeader { + shardCoordinator := mock.NewOneShardCoordinatorMock() + return createMetaArgumentWithShardCoordinator(shardCoordinator) +} + +func createMetaArgumentWithShardCoordinator(shardCoordinator sharding.Coordinator) 
*interceptedBlocks.ArgInterceptedBlockHeader { + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) arg := &interceptedBlocks.ArgInterceptedBlockHeader{ - ShardCoordinator: mock.NewOneShardCoordinatorMock(), + ShardCoordinator: shardCoordinator, Hasher: testHasher, Marshalizer: testMarshalizer, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, ValidityAttester: &mock.ValidityAttesterStub{}, EpochStartTrigger: &mock.EpochStartTriggerStub{ @@ -28,6 +41,8 @@ func createDefaultMetaArgument() *interceptedBlocks.ArgInterceptedBlockHeader { return hdrEpoch }, }, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + EpochChangeGracePeriodHandler: gracePeriod, } hdr := createMockMetaHeader() @@ -60,7 +75,7 @@ func createMockMetaHeader() *dataBlock.MetaBlock { } } -//------- TestNewInterceptedHeader +// ------- TestNewInterceptedHeader func TestNewInterceptedMetaHeader_NilArgumentShouldErr(t *testing.T) { t.Parallel() @@ -95,7 +110,7 @@ func TestNewInterceptedMetaHeader_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- CheckValidity +// ------- CheckValidity func TestInterceptedMetaHeader_CheckValidityNilPubKeyBitmapShouldErr(t *testing.T) { t.Parallel() @@ -128,7 +143,10 @@ func TestInterceptedMetaHeader_ErrorInMiniBlockShouldErr(t *testing.T) { } buff, _ := testMarshalizer.Marshal(hdr) - arg := createDefaultMetaArgument() + shardCoordinator := mock.NewOneShardCoordinatorMock() + _ = shardCoordinator.SetSelfId(core.MetachainShardId) + + arg := createMetaArgumentWithShardCoordinator(shardCoordinator) arg.HdrBuff = buff inHdr, _ := interceptedBlocks.NewInterceptedMetaHeader(arg) @@ -182,7 +200,7 @@ func TestInterceptedMetaHeader_CheckAgainstFinalHeaderAttesterFailsShouldErr(t * assert.Equal(t, expectedErr, err) } -//------- getters +// ------- getters func TestInterceptedMetaHeader_Getters(t *testing.T) { t.Parallel() @@ -204,7 +222,7 @@ func TestInterceptedMetaHeader_CheckValidityLeaderSignatureNotCorrectShouldErr(t buff, _ := testMarshalizer.Marshal(hdr) arg := createDefaultMetaArgument() - arg.HeaderSigVerifier = &mock.HeaderSigVerifierStub{ + arg.HeaderSigVerifier = &consensus.HeaderSigVerifierMock{ VerifyRandSeedAndLeaderSignatureCalled: func(header data.HeaderHandler) error { return expectedErr }, @@ -283,7 +301,7 @@ func TestInterceptedMetaHeader_isMetaHeaderEpochOutOfRange(t *testing.T) { }) } -//------- IsInterfaceNil +// ------- IsInterfaceNil func TestInterceptedMetaHeader_IsInterfaceNil(t *testing.T) { t.Parallel() diff --git a/process/block/interceptedBlocks/interceptedMiniblock_test.go b/process/block/interceptedBlocks/interceptedMiniblock_test.go index 57d53ec251d..46b489b259d 100644 --- a/process/block/interceptedBlocks/interceptedMiniblock_test.go +++ b/process/block/interceptedBlocks/interceptedMiniblock_test.go @@ -5,10 +5,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/process/mock" - "github.com/stretchr/testify/assert" ) func createDefaultMiniblockArgument() 
*interceptedBlocks.ArgInterceptedMiniblock { @@ -69,7 +70,7 @@ func TestNewInterceptedMiniblock_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- CheckValidity +//------- Verify func TestInterceptedMiniblock_InvalidReceiverShardIdShouldErr(t *testing.T) { t.Parallel() diff --git a/process/block/metablock.go b/process/block/metablock.go index d9832a27fc8..bbdf27b54fd 100644 --- a/process/block/metablock.go +++ b/process/block/metablock.go @@ -3,6 +3,7 @@ package block import ( "bytes" "encoding/hex" + "errors" "fmt" "math/big" "sync" @@ -13,6 +14,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/headerVersionData" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/holders" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -22,7 +25,6 @@ import ( "github.com/multiversx/mx-chain-go/process/block/helpers" "github.com/multiversx/mx-chain-go/process/block/processedMb" "github.com/multiversx/mx-chain-go/state" - logger "github.com/multiversx/mx-chain-logger-go" ) const firstHeaderNonce = uint64(1) @@ -42,7 +44,6 @@ type metaProcessor struct { validatorStatisticsProcessor process.ValidatorStatisticsProcessor shardsHeadersNonce *sync.Map shardBlockFinality uint32 - chRcvAllHdrs chan bool headersCounter *headersCounter } @@ -123,6 +124,7 @@ func NewMetaProcessor(arguments ArgMetaProcessor) (*metaProcessor, error) { enableEpochsHandler: arguments.CoreComponents.EnableEpochsHandler(), roundNotifier: arguments.CoreComponents.RoundNotifier(), enableRoundsHandler: arguments.CoreComponents.EnableRoundsHandler(), + epochChangeGracePeriodHandler: arguments.CoreComponents.EpochChangeGracePeriodHandler(), vmContainerFactory: arguments.VMContainersFactory, vmContainer: arguments.VmContainer, processDataTriesOnCommitEpoch: arguments.Config.Debug.EpochStart.ProcessDataTrieOnCommitEpoch, @@ -139,6 +141,7 @@ func NewMetaProcessor(arguments ArgMetaProcessor) (*metaProcessor, error) { managedPeersHolder: arguments.ManagedPeersHolder, sentSignaturesTracker: arguments.SentSignaturesTracker, extraDelayRequestBlockInfo: time.Duration(arguments.Config.EpochStartConfig.ExtraDelayForRequestBlockInfoInMilliseconds) * time.Millisecond, + proofsPool: arguments.DataComponents.Datapool().Proofs(), } mp := metaProcessor{ @@ -173,6 +176,8 @@ func NewMetaProcessor(arguments ArgMetaProcessor) (*metaProcessor, error) { headersPool := mp.dataPool.Headers() headersPool.RegisterHandler(mp.receivedShardHeader) + mp.proofsPool.RegisterHandler(mp.checkReceivedProofIfAttestingIsNeeded) + mp.chRcvAllHdrs = make(chan bool) mp.shardBlockFinality = process.BlockFinality @@ -201,7 +206,7 @@ func (mp *metaProcessor) ProcessBlock( err := mp.checkBlockValidity(headerHandler, bodyHandler) if err != nil { - if err == process.ErrBlockHashDoesNotMatch { + if errors.Is(err, process.ErrBlockHashDoesNotMatch) { log.Debug("requested missing meta header", "hash", headerHandler.GetPrevHash(), "for shard", headerHandler.GetShardID(), @@ -301,7 +306,7 @@ func (mp *metaProcessor) ProcessBlock( } mp.txCoordinator.RequestBlockTransactions(body) - requestedShardHdrs, requestedFinalityAttestingShardHdrs := mp.requestShardHeaders(header) + requestedShardHdrs, 
requestedFinalityAttestingShardHdrs, requestedProofs := mp.requestShardHeaders(header) if haveTime() < 0 { return process.ErrTimeIsOut @@ -312,7 +317,7 @@ func (mp *metaProcessor) ProcessBlock( return err } - haveMissingShardHeaders := requestedShardHdrs > 0 || requestedFinalityAttestingShardHdrs > 0 + haveMissingShardHeaders := requestedShardHdrs > 0 || requestedFinalityAttestingShardHdrs > 0 || requestedProofs > 0 if haveMissingShardHeaders { if requestedShardHdrs > 0 { log.Debug("requested missing shard headers", @@ -324,11 +329,17 @@ func (mp *metaProcessor) ProcessBlock( "num finality shard headers", requestedFinalityAttestingShardHdrs, ) } + if requestedProofs > 0 { + log.Debug("requested missing shard header proofs", + "num proofs", requestedProofs, + ) + } err = mp.waitForBlockHeaders(haveTime()) mp.hdrsForCurrBlock.mutHdrsForBlock.RLock() missingShardHdrs := mp.hdrsForCurrBlock.missingHdrs + missingProofs := mp.hdrsForCurrBlock.missingProofs mp.hdrsForCurrBlock.mutHdrsForBlock.RUnlock() mp.hdrsForCurrBlock.resetMissingHdrs() @@ -338,6 +349,11 @@ func (mp *metaProcessor) ProcessBlock( "num headers", requestedShardHdrs-missingShardHdrs, ) } + if requestedProofs > 0 { + log.Debug("received missing shard header proofs", + "num proofs", requestedProofs-missingProofs, + ) + } if err != nil { return err @@ -652,7 +668,15 @@ func (mp *metaProcessor) indexBlock( log.Debug("indexed block", "hash", headerHash, "nonce", metaBlock.GetNonce(), "round", metaBlock.GetRound()) - indexRoundInfo(mp.outportHandler, mp.nodesCoordinator, core.MetachainShardId, metaBlock, lastMetaBlock, argSaveBlock.SignersIndexes) + indexRoundInfo( + mp.outportHandler, + mp.nodesCoordinator, + core.MetachainShardId, + metaBlock, + lastMetaBlock, + argSaveBlock.SignersIndexes, + mp.enableEpochsHandler, + ) if metaBlock.GetNonce() != 1 && !metaBlock.IsStartOfEpochBlock() { return @@ -695,7 +719,7 @@ func (mp *metaProcessor) RestoreBlockIntoPools(headerHandler data.HeaderHandler, headersPool.AddHeader(hdrHash, shardHeader) - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(shardHeader.GetShardID()) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(shardHeader.GetShardID()) storer, errNotCritical := mp.store.GetStorer(hdrNonceHashDataUnit) if errNotCritical != nil { log.Debug("storage unit not found", "unit", hdrNonceHashDataUnit, "error", errNotCritical.Error()) @@ -718,6 +742,33 @@ func (mp *metaProcessor) RestoreBlockIntoPools(headerHandler data.HeaderHandler, return nil } +func (mp *metaProcessor) updateHeaderForEpochStartIfNeeded(metaHdr *block.MetaBlock) error { + isEpochStart := mp.epochStartTrigger.IsEpochStart() + if !isEpochStart { + return nil + } + return mp.updateEpochStartHeader(metaHdr) +} + +func (mp *metaProcessor) createBody(metaHdr *block.MetaBlock, haveTime func() bool) (data.BodyHandler, error) { + isEpochStart := mp.epochStartTrigger.IsEpochStart() + var body data.BodyHandler + var err error + if isEpochStart { + body, err = mp.createEpochStartBody(metaHdr) + if err != nil { + return nil, err + } + } else { + body, err = mp.createBlockBody(metaHdr, haveTime) + if err != nil { + return nil, err + } + } + + return body, nil +} + // CreateBlock creates the final block and header for the current round func (mp *metaProcessor) CreateBlock( initialHdr data.HeaderHandler, @@ -750,31 +801,19 @@ func (mp *metaProcessor) CreateBlock( return nil, nil, err } - if mp.epochStartTrigger.IsEpochStart() { - err = mp.updateEpochStartHeader(metaHdr) - if err != nil { 
- return nil, nil, err - } - - err = mp.blockChainHook.SetCurrentHeader(metaHdr) - if err != nil { - return nil, nil, err - } + err = mp.updateHeaderForEpochStartIfNeeded(metaHdr) + if err != nil { + return nil, nil, err + } - body, err = mp.createEpochStartBody(metaHdr) - if err != nil { - return nil, nil, err - } - } else { - err = mp.blockChainHook.SetCurrentHeader(metaHdr) - if err != nil { - return nil, nil, err - } + err = mp.blockChainHook.SetCurrentHeader(metaHdr) + if err != nil { + return nil, nil, err + } - body, err = mp.createBlockBody(metaHdr, haveTime) - if err != nil { - return nil, nil, err - } + body, err = mp.createBody(metaHdr, haveTime) + if err != nil { + return nil, nil, err } body, err = mp.applyBodyToHeader(metaHdr, body) @@ -1094,8 +1133,23 @@ func (mp *metaProcessor) createAndProcessCrossMiniBlocksDstMe( continue } + shouldCheckProof := mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, currShardHdr.GetEpoch()) + if shouldCheckProof { + hasProofForHdr := mp.proofsPool.HasProof(currShardHdr.GetShardID(), orderedHdrsHashes[i]) + if !hasProofForHdr { + log.Trace("no proof for shard header", + "shard", currShardHdr.GetShardID(), + "hash", logger.DisplayByteSlice(orderedHdrsHashes[i]), + ) + continue + } + } + if len(currShardHdr.GetMiniBlockHeadersWithDst(mp.shardCoordinator.SelfId())) == 0 { - mp.hdrsForCurrBlock.hdrHashAndInfo[string(orderedHdrsHashes[i])] = &hdrInfo{hdr: currShardHdr, usedInBlock: true} + mp.hdrsForCurrBlock.hdrHashAndInfo[string(orderedHdrsHashes[i])] = &hdrInfo{ + hdr: currShardHdr, + usedInBlock: true, + } hdrsAdded++ hdrsAddedForShard[currShardHdr.GetShardID()]++ lastShardHdr[currShardHdr.GetShardID()] = currShardHdr @@ -1124,7 +1178,7 @@ func (mp *metaProcessor) createAndProcessCrossMiniBlocksDstMe( // shard header must be processed completely errAccountState := mp.accountsDB[state.UserAccountsState].RevertToSnapshot(snapshot) if errAccountState != nil { - // TODO: evaluate if reloading the trie from disk will might solve the problem + // TODO: evaluate if reloading the trie from disk might solve the problem log.Warn("accounts.RevertToSnapshot", "error", errAccountState.Error()) } continue @@ -1134,7 +1188,10 @@ func (mp *metaProcessor) createAndProcessCrossMiniBlocksDstMe( miniBlocks = append(miniBlocks, currMBProcessed...) 
txsAdded += currTxsAdded - mp.hdrsForCurrBlock.hdrHashAndInfo[string(orderedHdrsHashes[i])] = &hdrInfo{hdr: currShardHdr, usedInBlock: true} + mp.hdrsForCurrBlock.hdrHashAndInfo[string(orderedHdrsHashes[i])] = &hdrInfo{ + hdr: currShardHdr, + usedInBlock: true, + } hdrsAdded++ hdrsAddedForShard[currShardHdr.GetShardID()]++ @@ -1283,21 +1340,8 @@ func (mp *metaProcessor) CommitBlock( "nonce", mp.forkDetector.GetHighestFinalBlockNonce(), ) - lastHeader := mp.blockChain.GetCurrentBlockHeader() - lastMetaBlock, ok := lastHeader.(data.MetaHeaderHandler) - if !ok { - if headerHandler.GetNonce() == firstHeaderNonce { - log.Debug("metaBlock.CommitBlock - nil current block header, this is expected at genesis time") - } else { - log.Error("metaBlock.CommitBlock - nil current block header, last current header should have not been nil") - } - } - lastMetaBlockHash := mp.blockChain.GetCurrentBlockHeaderHash() - if mp.lastRestartNonce == 0 { - mp.lastRestartNonce = header.GetNonce() - } - - mp.updateState(lastMetaBlock, lastMetaBlockHash) + finalMetaBlock, finalMetaBlockHash := mp.computeFinalMetaBlock(header, headerHash) + mp.updateState(finalMetaBlock, finalMetaBlockHash) committedRootHash, err := mp.accountsDB[state.UserAccountsState].RootHash() if err != nil { @@ -1311,12 +1355,12 @@ func (mp *metaProcessor) CommitBlock( mp.blockChain.SetCurrentBlockHeaderHash(headerHash) - if !check.IfNil(lastMetaBlock) && lastMetaBlock.IsStartOfEpochBlock() { + if !check.IfNil(finalMetaBlock) && finalMetaBlock.IsStartOfEpochBlock() { mp.blockTracker.CleanupInvalidCrossHeaders(header.Epoch, header.Round) } // TODO: Should be sent also validatorInfoTxs alongside rewardsTxs -> mp.validatorInfoCreator.GetValidatorInfoTxs(body) ? - mp.indexBlock(header, headerHash, body, lastMetaBlock, notarizedHeadersHashes, rewardsTxs) + mp.indexBlock(header, headerHash, body, finalMetaBlock, notarizedHeadersHashes, rewardsTxs) mp.recordBlockInHistory(headerHash, headerHandler, bodyHandler) highestFinalBlockNonce := mp.forkDetector.GetHighestFinalBlockNonce() @@ -1337,6 +1381,7 @@ func (mp *metaProcessor) CommitBlock( headerHash, numShardHeadersFromPool, mp.blockTracker, + mp.dataPool, ) }() @@ -1378,6 +1423,38 @@ func (mp *metaProcessor) CommitBlock( return nil } +func (mp *metaProcessor) computeFinalMetaBlock(metaBlock *block.MetaBlock, metaBlockHash []byte) (data.MetaHeaderHandler, []byte) { + lastHeader := mp.blockChain.GetCurrentBlockHeader() + lastMetaBlock, ok := lastHeader.(data.MetaHeaderHandler) + if !ok { + if metaBlock.GetNonce() == firstHeaderNonce { + log.Debug("metaBlock.CommitBlock - nil current block header, this is expected at genesis time") + } else { + log.Error("metaBlock.CommitBlock - nil current block header, last current header should have not been nil") + } + } + lastMetaBlockHash := mp.blockChain.GetCurrentBlockHeaderHash() + if mp.lastRestartNonce == 0 { + mp.lastRestartNonce = metaBlock.GetNonce() + } + + finalMetaBlock := lastMetaBlock + finalMetaBlockHash := lastMetaBlockHash + isBlockAfterAndromedaFlag := !check.IfNil(finalMetaBlock) && + mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, finalMetaBlock.GetEpoch()) + if isBlockAfterAndromedaFlag { + // for the first block we need to update both the state of the previous one and for current + if common.IsEpochChangeBlockForFlagActivation(metaBlock, mp.enableEpochsHandler, common.AndromedaFlag) { + mp.updateState(lastMetaBlock, lastMetaBlockHash) + } + + finalMetaBlock = metaBlock + finalMetaBlockHash = metaBlockHash + } + + return 
finalMetaBlock, finalMetaBlockHash +} + func (mp *metaProcessor) updateCrossShardInfo(metaBlock *block.MetaBlock) ([]string, error) { mp.hdrsForCurrBlock.mutHdrsForBlock.RLock() defer mp.hdrsForCurrBlock.mutHdrsForBlock.RUnlock() @@ -1441,15 +1518,15 @@ func (mp *metaProcessor) displayPoolsInfo() { mp.displayMiniBlocksPool() } -func (mp *metaProcessor) updateState(lastMetaBlock data.MetaHeaderHandler, lastMetaBlockHash []byte) { - if check.IfNil(lastMetaBlock) { +func (mp *metaProcessor) updateState(metaBlock data.MetaHeaderHandler, metaBlockHash []byte) { + if check.IfNil(metaBlock) { log.Debug("updateState nil header") return } - mp.validatorStatisticsProcessor.SetLastFinalizedRootHash(lastMetaBlock.GetValidatorStatsRootHash()) + mp.validatorStatisticsProcessor.SetLastFinalizedRootHash(metaBlock.GetValidatorStatsRootHash()) - prevMetaBlockHash := lastMetaBlock.GetPrevHash() + prevMetaBlockHash := metaBlock.GetPrevHash() prevMetaBlock, errNotCritical := process.GetMetaHeader( prevMetaBlockHash, mp.dataPool.Headers(), @@ -1461,20 +1538,20 @@ func (mp *metaProcessor) updateState(lastMetaBlock data.MetaHeaderHandler, lastM return } - if lastMetaBlock.IsStartOfEpochBlock() { + if metaBlock.IsStartOfEpochBlock() { log.Debug("trie snapshot", - "rootHash", lastMetaBlock.GetRootHash(), + "rootHash", metaBlock.GetRootHash(), "prevRootHash", prevMetaBlock.GetRootHash(), - "validatorStatsRootHash", lastMetaBlock.GetValidatorStatsRootHash()) - mp.accountsDB[state.UserAccountsState].SnapshotState(lastMetaBlock.GetRootHash(), lastMetaBlock.GetEpoch()) - mp.accountsDB[state.PeerAccountsState].SnapshotState(lastMetaBlock.GetValidatorStatsRootHash(), lastMetaBlock.GetEpoch()) + "validatorStatsRootHash", metaBlock.GetValidatorStatsRootHash()) + mp.accountsDB[state.UserAccountsState].SnapshotState(metaBlock.GetRootHash(), metaBlock.GetEpoch()) + mp.accountsDB[state.PeerAccountsState].SnapshotState(metaBlock.GetValidatorStatsRootHash(), metaBlock.GetEpoch()) go func() { - metaBlock, ok := lastMetaBlock.(*block.MetaBlock) + metaBlock, ok := metaBlock.(*block.MetaBlock) if !ok { - log.Warn("cannot commit Trie Epoch Root Hash: lastMetaBlock is not *block.MetaBlock") + log.Warn("cannot commit Trie Epoch Root Hash: metaBlock is not *block.MetaBlock") return } - err := mp.commitTrieEpochRootHashIfNeeded(metaBlock, lastMetaBlock.GetRootHash()) + err := mp.commitTrieEpochRootHashIfNeeded(metaBlock, metaBlock.GetRootHash()) if err != nil { log.Warn("couldn't commit trie checkpoint", "epoch", metaBlock.Epoch, "error", err) } @@ -1482,21 +1559,26 @@ func (mp *metaProcessor) updateState(lastMetaBlock data.MetaHeaderHandler, lastM } mp.updateStateStorage( - lastMetaBlock, - lastMetaBlock.GetRootHash(), + metaBlock, + metaBlock.GetRootHash(), prevMetaBlock.GetRootHash(), mp.accountsDB[state.UserAccountsState], ) mp.updateStateStorage( - lastMetaBlock, - lastMetaBlock.GetValidatorStatsRootHash(), + metaBlock, + metaBlock.GetValidatorStatsRootHash(), prevMetaBlock.GetValidatorStatsRootHash(), mp.accountsDB[state.PeerAccountsState], ) - mp.setFinalizedHeaderHashInIndexer(lastMetaBlock.GetPrevHash()) - mp.blockChain.SetFinalBlockInfo(lastMetaBlock.GetNonce(), lastMetaBlockHash, lastMetaBlock.GetRootHash()) + outportFinalizedHeaderHash := metaBlockHash + if !common.IsFlagEnabledAfterEpochsStartBlock(metaBlock, mp.enableEpochsHandler, common.AndromedaFlag) { + outportFinalizedHeaderHash = metaBlock.GetPrevHash() + } + mp.setFinalizedHeaderHashInIndexer(outportFinalizedHeaderHash) + + 
mp.blockChain.SetFinalBlockInfo(metaBlock.GetNonce(), metaBlockHash, metaBlock.GetRootHash()) } func (mp *metaProcessor) getLastSelfNotarizedHeaderByShard( @@ -1736,7 +1818,10 @@ func (mp *metaProcessor) getLastCrossNotarizedShardHdrs() (map[uint32]data.Heade log.Debug("lastCrossNotarizedHeader for shard", "shardID", shardID, "hash", hash) lastCrossNotarizedHeader[shardID] = lastCrossNotarizedHeaderForShard usedInBlock := mp.isGenesisShardBlockAndFirstMeta(lastCrossNotarizedHeaderForShard.GetNonce()) - mp.hdrsForCurrBlock.hdrHashAndInfo[string(hash)] = &hdrInfo{hdr: lastCrossNotarizedHeaderForShard, usedInBlock: usedInBlock} + mp.hdrsForCurrBlock.hdrHashAndInfo[string(hash)] = &hdrInfo{ + hdr: lastCrossNotarizedHeaderForShard, + usedInBlock: usedInBlock, + } } return lastCrossNotarizedHeader, nil @@ -1750,7 +1835,10 @@ func (mp *metaProcessor) checkShardHeadersValidity(metaHdr *block.MetaBlock) (ma return nil, err } - usedShardHdrs := mp.sortHeadersForCurrentBlockByNonce(true) + usedShardHdrs, err := mp.sortHeadersForCurrentBlockByNonce(true) + if err != nil { + return nil, err + } highestNonceHdrs := make(map[uint32]data.HeaderHandler, len(usedShardHdrs)) if len(usedShardHdrs) == 0 { @@ -1797,6 +1885,11 @@ func (mp *metaProcessor) checkShardHeadersValidity(metaHdr *block.MetaBlock) (ma if shardData.DeveloperFees.Cmp(shardHdr.GetDeveloperFees()) != 0 { return nil, process.ErrDeveloperFeesDoNotMatch } + if mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, shardHdr.GetEpoch()) { + if shardData.Epoch != shardHdr.GetEpoch() { + return nil, process.ErrEpochMismatch + } + } mapMiniBlockHeadersInMetaBlock := make(map[string]struct{}) for _, shardMiniBlockHdr := range shardData.ShardMiniBlockHeaders { @@ -1835,7 +1928,10 @@ func (mp *metaProcessor) getFinalMiniBlockHeaders(miniBlockHeaderHandlers []data func (mp *metaProcessor) checkShardHeadersFinality( highestNonceHdrs map[uint32]data.HeaderHandler, ) error { - finalityAttestingShardHdrs := mp.sortHeadersForCurrentBlockByNonce(false) + finalityAttestingShardHdrs, err := mp.sortHeadersForCurrentBlockByNonce(false) + if err != nil { + return err + } var errFinal error @@ -1851,6 +1947,15 @@ func (mp *metaProcessor) checkShardHeadersFinality( continue } + isNotarizedBasedOnProofs, errCheckProof := mp.checkShardHeaderFinalityBasedOnProofs(lastVerifiedHdr, shardId) + if isNotarizedBasedOnProofs { + if errCheckProof != nil { + return errCheckProof + } + + continue + } + // verify if there are "K" block after current to make this one final nextBlocksVerified := uint32(0) for _, shardHdr := range finalityAttestingShardHdrs[shardId] { @@ -1867,12 +1972,21 @@ func (mp *metaProcessor) checkShardHeadersFinality( continue } + isNotarizedBasedOnProofs, errCheckProof = mp.checkShardHeaderFinalityBasedOnProofs(shardHdr, shardId) + if isNotarizedBasedOnProofs { + if errCheckProof != nil { + return errCheckProof + } + + break + } + lastVerifiedHdr = shardHdr nextBlocksVerified += 1 } } - if nextBlocksVerified < mp.shardBlockFinality { + if nextBlocksVerified < mp.shardBlockFinality && !isNotarizedBasedOnProofs { go mp.requestHandler.RequestShardHeaderByNonce(lastVerifiedHdr.GetShardID(), lastVerifiedHdr.GetNonce()) go mp.requestHandler.RequestShardHeaderByNonce(lastVerifiedHdr.GetShardID(), lastVerifiedHdr.GetNonce()+1) errFinal = process.ErrHeaderNotFinal @@ -1882,6 +1996,23 @@ func (mp *metaProcessor) checkShardHeadersFinality( return errFinal } +func (mp *metaProcessor) checkShardHeaderFinalityBasedOnProofs(shardHdr data.HeaderHandler, shardId 
uint32) (bool, error) { + if !mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, shardHdr.GetEpoch()) { + return false, nil + } + + headerHash, err := core.CalculateHash(mp.marshalizer, mp.hasher, shardHdr) + if err != nil { + return true, err + } + + if !mp.proofsPool.HasProof(shardId, headerHash) { + return true, process.ErrHeaderNotFinal + } + + return true, nil +} + // receivedShardHeader is a call back function which is called when a new header // is added in the headers pool func (mp *metaProcessor) receivedShardHeader(headerHandler data.HeaderHandler, shardHeaderHash []byte) { @@ -1904,6 +2035,8 @@ func (mp *metaProcessor) receivedShardHeader(headerHandler data.HeaderHandler, s hdrInfoForHash := mp.hdrsForCurrBlock.hdrHashAndInfo[string(shardHeaderHash)] headerInfoIsNotNil := hdrInfoForHash != nil headerIsMissing := headerInfoIsNotNil && check.IfNil(hdrInfoForHash.hdr) + hasProof := headerInfoIsNotNil && hdrInfoForHash.hasProof + hasProofRequested := headerInfoIsNotNil && hdrInfoForHash.hasProofRequested if headerIsMissing { hdrInfoForHash.hdr = shardHeader mp.hdrsForCurrBlock.missingHdrs-- @@ -1911,6 +2044,11 @@ func (mp *metaProcessor) receivedShardHeader(headerHandler data.HeaderHandler, s if shardHeader.GetNonce() > mp.hdrsForCurrBlock.highestHdrNonce[shardHeader.GetShardID()] { mp.hdrsForCurrBlock.highestHdrNonce[shardHeader.GetShardID()] = shardHeader.GetNonce() } + mp.updateLastNotarizedBlockForShard(shardHeader, shardHeaderHash) + + if !hasProof && !hasProofRequested { + mp.requestProofIfNeeded(shardHeaderHash, shardHeader) + } } if mp.hdrsForCurrBlock.missingHdrs == 0 { @@ -1922,9 +2060,10 @@ func (mp *metaProcessor) receivedShardHeader(headerHandler data.HeaderHandler, s missingShardHdrs := mp.hdrsForCurrBlock.missingHdrs missingFinalityAttestingShardHdrs := mp.hdrsForCurrBlock.missingFinalityAttestingHdrs + missingProofs := mp.hdrsForCurrBlock.missingProofs mp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() - allMissingShardHeadersReceived := missingShardHdrs == 0 && missingFinalityAttestingShardHdrs == 0 + allMissingShardHeadersReceived := missingShardHdrs == 0 && missingFinalityAttestingShardHdrs == 0 && missingProofs == 0 if allMissingShardHeadersReceived { mp.chRcvAllHdrs <- true } @@ -1937,34 +2076,58 @@ func (mp *metaProcessor) receivedShardHeader(headerHandler data.HeaderHandler, s // requestMissingFinalityAttestingShardHeaders requests the headers needed to accept the current selected headers for // processing the current block. 
It requests the shardBlockFinality headers greater than the highest shard header, -// for each shard, related to the block which should be processed +// for the given shard, related to the block which should be processed // this method should be called only under the mutex protection: hdrsForCurrBlock.mutHdrsForBlock func (mp *metaProcessor) requestMissingFinalityAttestingShardHeaders() uint32 { missingFinalityAttestingShardHeaders := uint32(0) for shardId := uint32(0); shardId < mp.shardCoordinator.NumberOfShards(); shardId++ { - missingFinalityAttestingHeaders := mp.requestMissingFinalityAttestingHeaders( - shardId, - mp.shardBlockFinality, - ) - + lastNotarizedShardHeader := mp.hdrsForCurrBlock.lastNotarizedShardHeaders[shardId] + missingFinalityAttestingHeaders := uint32(0) + if lastNotarizedShardHeader != nil && !lastNotarizedShardHeader.notarizedBasedOnProof { + missingFinalityAttestingHeaders = mp.requestMissingFinalityAttestingHeaders( + shardId, + mp.shardBlockFinality, + ) + } missingFinalityAttestingShardHeaders += missingFinalityAttestingHeaders } return missingFinalityAttestingShardHeaders } -func (mp *metaProcessor) requestShardHeaders(metaBlock *block.MetaBlock) (uint32, uint32) { +func (mp *metaProcessor) requestShardHeaders(metaBlock *block.MetaBlock) (uint32, uint32, uint32) { _ = core.EmptyChannel(mp.chRcvAllHdrs) if len(metaBlock.ShardInfo) == 0 { - return 0, 0 + return 0, 0, 0 } return mp.computeExistingAndRequestMissingShardHeaders(metaBlock) } -func (mp *metaProcessor) computeExistingAndRequestMissingShardHeaders(metaBlock *block.MetaBlock) (uint32, uint32) { +func (mp *metaProcessor) updateLastNotarizedBlockForShard(hdr data.ShardHeaderHandler, headerHash []byte) { + lastNotarizedForShard := mp.hdrsForCurrBlock.lastNotarizedShardHeaders[hdr.GetShardID()] + if lastNotarizedForShard == nil { + lastNotarizedForShard = &lastNotarizedHeaderInfo{header: hdr} + mp.hdrsForCurrBlock.lastNotarizedShardHeaders[hdr.GetShardID()] = lastNotarizedForShard + } + + if hdr.GetNonce() >= lastNotarizedForShard.header.GetNonce() { + notarizedBasedOnProofs := mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, hdr.GetEpoch()) + hasProof := false + if notarizedBasedOnProofs { + hasProof = mp.proofsPool.HasProof(hdr.GetShardID(), headerHash) + } + + lastNotarizedForShard.header = hdr + lastNotarizedForShard.hash = headerHash + lastNotarizedForShard.notarizedBasedOnProof = notarizedBasedOnProofs + lastNotarizedForShard.hasProof = hasProof + } +} + +func (mp *metaProcessor) computeExistingAndRequestMissingShardHeaders(metaBlock *block.MetaBlock) (uint32, uint32, uint32) { mp.hdrsForCurrBlock.mutHdrsForBlock.Lock() defer mp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() @@ -1976,7 +2139,7 @@ func (mp *metaProcessor) computeExistingAndRequestMissingShardHeaders(metaBlock continue } if !bytes.Equal(hash, shardData.HeaderHash) { - log.Warn("genesis hash missmatch", + log.Warn("genesis hash mismatch", "last notarized nonce", lastCrossNotarizedHeaderForShard.GetNonce(), "last notarized hash", hash, "genesis nonce", mp.genesisNonce, @@ -1992,9 +2155,12 @@ func (mp *metaProcessor) computeExistingAndRequestMissingShardHeaders(metaBlock if err != nil { mp.hdrsForCurrBlock.missingHdrs++ mp.hdrsForCurrBlock.hdrHashAndInfo[string(shardData.HeaderHash)] = &hdrInfo{ - hdr: nil, - usedInBlock: true, + hdr: nil, + usedInBlock: true, + hasProof: false, + hasProofRequested: false, } + go mp.requestHandler.RequestShardHeader(shardData.ShardID, shardData.HeaderHash) continue } @@ -2004,16 +2170,24 @@ 
func (mp *metaProcessor) computeExistingAndRequestMissingShardHeaders(metaBlock usedInBlock: true, } + mp.requestProofIfNeeded(shardData.HeaderHash, hdr) + + if common.IsEpochChangeBlockForFlagActivation(hdr, mp.enableEpochsHandler, common.AndromedaFlag) { + continue + } + if hdr.GetNonce() > mp.hdrsForCurrBlock.highestHdrNonce[shardData.ShardID] { mp.hdrsForCurrBlock.highestHdrNonce[shardData.ShardID] = hdr.GetNonce() } + + mp.updateLastNotarizedBlockForShard(hdr, shardData.HeaderHash) } if mp.hdrsForCurrBlock.missingHdrs == 0 { mp.hdrsForCurrBlock.missingFinalityAttestingHdrs = mp.requestMissingFinalityAttestingShardHeaders() } - return mp.hdrsForCurrBlock.missingHdrs, mp.hdrsForCurrBlock.missingFinalityAttestingHdrs + return mp.hdrsForCurrBlock.missingHdrs, mp.hdrsForCurrBlock.missingFinalityAttestingHdrs, mp.hdrsForCurrBlock.missingProofs } func (mp *metaProcessor) createShardInfo() ([]data.ShardDataHandler, error) { @@ -2030,6 +2204,13 @@ func (mp *metaProcessor) createShardInfo() ([]data.ShardDataHandler, error) { continue } + isBlockAfterAndromedaFlag := !check.IfNil(headerInfo.hdr) && + mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, headerInfo.hdr.GetEpoch()) && headerInfo.hdr.GetNonce() >= 1 + hasMissingShardHdrProof := isBlockAfterAndromedaFlag && !mp.proofsPool.HasProof(headerInfo.hdr.GetShardID(), []byte(hdrHash)) + if hasMissingShardHdrProof { + return nil, fmt.Errorf("%w for shard header with hash %s", process.ErrMissingHeaderProof, hex.EncodeToString([]byte(hdrHash))) + } + shardHdr, ok := headerInfo.hdr.(data.ShardHeaderHandler) if !ok { return nil, process.ErrWrongTypeAssertion @@ -2044,6 +2225,9 @@ func (mp *metaProcessor) createShardInfo() ([]data.ShardDataHandler, error) { shardData.Nonce = shardHdr.GetNonce() shardData.PrevRandSeed = shardHdr.GetPrevRandSeed() shardData.PubKeysBitmap = shardHdr.GetPubKeysBitmap() + if mp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, shardHdr.GetEpoch()) { + shardData.Epoch = shardHdr.GetEpoch() + } shardData.NumPendingMiniBlocks = uint32(len(mp.pendingMiniBlocksHandler.GetPendingMiniBlocks(shardData.ShardID))) header, _, err := mp.blockTracker.GetLastSelfNotarizedHeader(shardHdr.GetShardID()) if err != nil { @@ -2281,7 +2465,10 @@ func (mp *metaProcessor) prepareBlockHeaderInternalMapForValidatorProcessor() { } mp.hdrsForCurrBlock.mutHdrsForBlock.Lock() - mp.hdrsForCurrBlock.hdrHashAndInfo[string(currentBlockHeaderHash)] = &hdrInfo{false, currentBlockHeader} + mp.hdrsForCurrBlock.hdrHashAndInfo[string(currentBlockHeaderHash)] = &hdrInfo{ + usedInBlock: false, + hdr: currentBlockHeader, + } mp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() } @@ -2334,6 +2521,7 @@ func (mp *metaProcessor) CreateNewHeader(round uint64, nonce uint64) (data.Heade } mp.roundNotifier.CheckRound(header) + mp.epochNotifier.CheckEpoch(header) err = metaHeader.SetNonce(nonce) if err != nil { diff --git a/process/block/metablockRequest_test.go b/process/block/metablockRequest_test.go index 0718830a43c..2d9fdb5f89f 100644 --- a/process/block/metablockRequest_test.go +++ b/process/block/metablockRequest_test.go @@ -49,12 +49,13 @@ func TestMetaProcessor_computeExistingAndRequestMissingShardHeaders(t *testing.T require.NotNil(t, mp) headersForBlock := mp.GetHdrForBlock() - numMissing, numAttestationMissing := mp.ComputeExistingAndRequestMissingShardHeaders(metaBlock) + numMissing, numAttestationMissing, missingProofs := mp.ComputeExistingAndRequestMissingShardHeaders(metaBlock) time.Sleep(100 * time.Millisecond) require.Equal(t, 
uint32(2), numMissing) require.Equal(t, uint32(2), headersForBlock.GetMissingHdrs()) // before receiving all missing headers referenced in metaBlock, the number of missing attestations is not updated require.Equal(t, uint32(0), numAttestationMissing) + require.Equal(t, uint32(0), missingProofs) require.Equal(t, uint32(0), headersForBlock.GetMissingFinalityAttestingHdrs()) require.Len(t, headersForBlock.GetHdrHashAndInfo(), 2) require.Equal(t, uint32(0), numCallsMissingAttestation.Load()) @@ -85,13 +86,14 @@ func TestMetaProcessor_computeExistingAndRequestMissingShardHeaders(t *testing.T headersPool := mp.GetDataPool().Headers() // adding the existing header headersPool.AddHeader(td[0].referencedHeaderData.headerHash, td[0].referencedHeaderData.header) - numMissing, numAttestationMissing := mp.ComputeExistingAndRequestMissingShardHeaders(metaBlock) + numMissing, numAttestationMissing, missingProofs := mp.ComputeExistingAndRequestMissingShardHeaders(metaBlock) time.Sleep(100 * time.Millisecond) headersForBlock := mp.GetHdrForBlock() require.Equal(t, uint32(1), numMissing) require.Equal(t, uint32(1), headersForBlock.GetMissingHdrs()) // before receiving all missing headers referenced in metaBlock, the number of missing attestations is not updated require.Equal(t, uint32(0), numAttestationMissing) + require.Equal(t, uint32(0), missingProofs) require.Equal(t, uint32(0), headersForBlock.GetMissingFinalityAttestingHdrs()) require.Len(t, headersForBlock.GetHdrHashAndInfo(), 2) require.Equal(t, uint32(0), numCallsMissingAttestation.Load()) @@ -123,12 +125,13 @@ func TestMetaProcessor_computeExistingAndRequestMissingShardHeaders(t *testing.T // adding the existing headers headersPool.AddHeader(td[0].referencedHeaderData.headerHash, td[0].referencedHeaderData.header) headersPool.AddHeader(td[1].referencedHeaderData.headerHash, td[1].referencedHeaderData.header) - numMissing, numAttestationMissing := mp.ComputeExistingAndRequestMissingShardHeaders(metaBlock) + numMissing, numAttestationMissing, missingProofs := mp.ComputeExistingAndRequestMissingShardHeaders(metaBlock) time.Sleep(100 * time.Millisecond) headersForBlock := mp.GetHdrForBlock() require.Equal(t, uint32(0), numMissing) require.Equal(t, uint32(0), headersForBlock.GetMissingHdrs()) require.Equal(t, uint32(2), numAttestationMissing) + require.Equal(t, uint32(0), missingProofs) require.Equal(t, uint32(2), headersForBlock.GetMissingFinalityAttestingHdrs()) require.Len(t, headersForBlock.GetHdrHashAndInfo(), 2) require.Equal(t, uint32(2), numCallsMissingAttestation.Load()) @@ -161,12 +164,13 @@ func TestMetaProcessor_computeExistingAndRequestMissingShardHeaders(t *testing.T headersPool.AddHeader(td[0].referencedHeaderData.headerHash, td[0].referencedHeaderData.header) headersPool.AddHeader(td[1].referencedHeaderData.headerHash, td[1].referencedHeaderData.header) headersPool.AddHeader(td[0].attestationHeaderData.headerHash, td[0].attestationHeaderData.header) - numMissing, numAttestationMissing := mp.ComputeExistingAndRequestMissingShardHeaders(metaBlock) + numMissing, numAttestationMissing, missingProofs := mp.ComputeExistingAndRequestMissingShardHeaders(metaBlock) time.Sleep(100 * time.Millisecond) headersForBlock := mp.GetHdrForBlock() require.Equal(t, uint32(0), numMissing) require.Equal(t, uint32(0), headersForBlock.GetMissingHdrs()) require.Equal(t, uint32(1), numAttestationMissing) + require.Equal(t, uint32(0), missingProofs) require.Equal(t, uint32(1), headersForBlock.GetMissingFinalityAttestingHdrs()) require.Len(t, 
headersForBlock.GetHdrHashAndInfo(), 3) require.Equal(t, uint32(1), numCallsMissingAttestation.Load()) @@ -200,12 +204,13 @@ func TestMetaProcessor_computeExistingAndRequestMissingShardHeaders(t *testing.T headersPool.AddHeader(td[1].referencedHeaderData.headerHash, td[1].referencedHeaderData.header) headersPool.AddHeader(td[0].attestationHeaderData.headerHash, td[0].attestationHeaderData.header) headersPool.AddHeader(td[1].attestationHeaderData.headerHash, td[1].attestationHeaderData.header) - numMissing, numAttestationMissing := mp.ComputeExistingAndRequestMissingShardHeaders(metaBlock) + numMissing, numAttestationMissing, missingProofs := mp.ComputeExistingAndRequestMissingShardHeaders(metaBlock) time.Sleep(100 * time.Millisecond) headersForBlock := mp.GetHdrForBlock() require.Equal(t, uint32(0), numMissing) require.Equal(t, uint32(0), headersForBlock.GetMissingHdrs()) require.Equal(t, uint32(0), numAttestationMissing) + require.Equal(t, uint32(0), missingProofs) require.Equal(t, uint32(0), headersForBlock.GetMissingFinalityAttestingHdrs()) require.Len(t, headersForBlock.GetHdrHashAndInfo(), 4) require.Equal(t, uint32(0), numCallsMissingAttestation.Load()) diff --git a/process/block/metablock_test.go b/process/block/metablock_test.go index c78f2c5b039..1ff16022b10 100644 --- a/process/block/metablock_test.go +++ b/process/block/metablock_test.go @@ -13,7 +13,12 @@ import ( "github.com/multiversx/mx-chain-core-go/core/atomic" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/graceperiod" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" "github.com/multiversx/mx-chain-go/process" @@ -36,8 +41,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockComponentHolders() ( @@ -47,18 +50,19 @@ func createMockComponentHolders() ( *mock.StatusComponentsMock, ) { mdp := initDataPool([]byte("tx_hash")) - + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) coreComponents := &mock.CoreComponentsMock{ - IntMarsh: &mock.MarshalizerMock{}, - Hash: &mock.HasherStub{}, - UInt64ByteSliceConv: &mock.Uint64ByteSliceConverterMock{}, - StatusField: &statusHandlerMock.AppStatusHandlerStub{}, - RoundField: &mock.RoundHandlerMock{RoundTimeDuration: time.Second}, - ProcessStatusHandlerField: &testscommon.ProcessStatusHandlerStub{}, - EpochNotifierField: &epochNotifier.EpochNotifierStub{}, - EnableEpochsHandlerField: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), - RoundNotifierField: &epochNotifier.RoundNotifierStub{}, - EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, + IntMarsh: &mock.MarshalizerMock{}, + Hash: &mock.HasherStub{}, + UInt64ByteSliceConv: 
&mock.Uint64ByteSliceConverterMock{}, + StatusField: &statusHandlerMock.AppStatusHandlerStub{}, + RoundField: &mock.RoundHandlerMock{RoundTimeDuration: time.Second}, + ProcessStatusHandlerField: &testscommon.ProcessStatusHandlerStub{}, + EpochNotifierField: &epochNotifier.EpochNotifierStub{}, + EnableEpochsHandlerField: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + RoundNotifierField: &epochNotifier.RoundNotifierStub{}, + EnableRoundsHandlerField: &testscommon.EnableRoundsHandlerStub{}, + EpochChangeGracePeriodHandlerField: gracePeriod, } dataComponents := &mock.DataComponentsMock{ @@ -91,8 +95,9 @@ func createMockMetaArguments( ) blproc.ArgMetaProcessor { argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: &mock.HasherStub{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &mock.HasherStub{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } headerValidator, _ := blproc.NewHeaderValidator(argsHeaderValidator) @@ -818,10 +823,28 @@ func TestMetaProcessor_RequestFinalMissingHeaderShouldPass(t *testing.T) { mp, _ := blproc.NewMetaProcessor(arguments) mp.AddHdrHashToRequestedList(&block.Header{}, []byte("header_hash")) mp.SetHighestHdrNonceForCurrentBlock(0, 1) + mp.SetLastNotarizedHeaderForShard(0, &blproc.LastNotarizedHeaderInfo{ + Header: &block.Header{Nonce: 0, ShardID: 0}, + Hash: []byte("header hash"), + NotarizedBasedOnProof: false, + HasProof: false, + }) mp.SetHighestHdrNonceForCurrentBlock(1, 2) + mp.SetLastNotarizedHeaderForShard(1, &blproc.LastNotarizedHeaderInfo{ + Header: &block.Header{Nonce: 2, ShardID: 1}, + Hash: []byte("header hash"), + NotarizedBasedOnProof: false, + HasProof: false, + }) mp.SetHighestHdrNonceForCurrentBlock(2, 3) + mp.SetLastNotarizedHeaderForShard(2, &blproc.LastNotarizedHeaderInfo{ + Header: &block.Header{Nonce: 3, ShardID: 2}, + Hash: []byte("header hash"), + NotarizedBasedOnProof: false, + HasProof: false, + }) res := mp.RequestMissingFinalityAttestingShardHeaders() - assert.Equal(t, res, uint32(3)) + assert.Equal(t, uint32(3), res) } // ------- CommitBlock @@ -1121,7 +1144,7 @@ func TestBlockProc_RequestTransactionFromNetwork(t *testing.T) { } header := createMetaBlockHeader() - hdrsRequested, _ := mp.RequestBlockHeaders(header) + hdrsRequested, _, _ := mp.RequestBlockHeaders(header) assert.Equal(t, uint32(1), hdrsRequested) } @@ -1956,8 +1979,9 @@ func TestMetaProcessor_CheckShardHeadersValidity(t *testing.T) { arguments.BlockTracker = mock.NewBlockTrackerMock(bootstrapComponents.ShardCoordinator(), startHeaders) argsHeaderValidator := blproc.ArgsHeaderValidator{ - Hasher: coreComponents.Hash, - Marshalizer: coreComponents.InternalMarshalizer(), + Hasher: coreComponents.Hash, + Marshalizer: coreComponents.InternalMarshalizer(), + EnableEpochsHandler: coreComponents.EnableEpochsHandler(), } arguments.HeaderValidator, _ = blproc.NewHeaderValidator(argsHeaderValidator) diff --git a/process/block/metrics.go b/process/block/metrics.go index ce29ddb23f8..5e52e2b3980 100644 --- a/process/block/metrics.go +++ b/process/block/metrics.go @@ -12,11 +12,12 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" outportcore "github.com/multiversx/mx-chain-core-go/data/outport" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/outport" 
"github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - logger "github.com/multiversx/mx-chain-logger-go" ) const leaderIndex = 0 @@ -129,7 +130,7 @@ func incrementMetricCountConsensusAcceptedBlocks( appStatusHandler core.AppStatusHandler, managedPeersHolder common.ManagedPeersHolder, ) { - pubKeys, err := nodesCoordinator.GetConsensusValidatorsPublicKeys( + _, pubKeys, err := nodesCoordinator.GetConsensusValidatorsPublicKeys( header.GetPrevRandSeed(), header.GetRound(), header.GetShardID(), @@ -162,6 +163,7 @@ func indexRoundInfo( header data.HeaderHandler, lastHeader data.HeaderHandler, signersIndexes []uint64, + enableEpochsHandler common.EnableEpochsHandler, ) { roundInfo := &outportcore.RoundInfo{ Round: header.GetRound(), @@ -184,13 +186,9 @@ func indexRoundInfo( roundsInfo := make([]*outportcore.RoundInfo, 0) roundsInfo = append(roundsInfo, roundInfo) for i := lastBlockRound + 1; i < currentBlockRound; i++ { - publicKeys, err := nodesCoordinator.GetConsensusValidatorsPublicKeys(lastHeader.GetRandSeed(), i, shardId, lastHeader.GetEpoch()) - if err != nil { - continue - } - signersIndexes, err = nodesCoordinator.GetValidatorsIndexes(publicKeys, lastHeader.GetEpoch()) - if err != nil { - log.Error(err.Error(), "round", i) + var ok bool + signersIndexes, ok = getSignersIndices(header, enableEpochsHandler, lastHeader, i, nodesCoordinator) + if !ok { continue } @@ -209,6 +207,33 @@ func indexRoundInfo( outportHandler.SaveRoundsInfo(&outportcore.RoundsInfo{ShardID: shardId, RoundsInfo: roundsInfo}) } +func getSignersIndices( + header data.HeaderHandler, + enableEpochsHandler common.EnableEpochsHandler, + lastHeader data.HeaderHandler, + round uint64, + nodesCoordinator nodesCoordinator.NodesCoordinator, +) ([]uint64, bool) { + // if AndromedaFlag is active and all validators are in consensus group - signer indices no longer needed + if common.IsFlagEnabledAfterEpochsStartBlock(header, enableEpochsHandler, common.AndromedaFlag) { + return make([]uint64, 0), true + } + + _, publicKeys, err := nodesCoordinator.GetConsensusValidatorsPublicKeys(lastHeader.GetRandSeed(), round, header.GetShardID(), lastHeader.GetEpoch()) + if err != nil { + log.Error("getSignersIndices: cannot get validators public keys", "error", err.Error(), "round", round) + return nil, false + } + + signersIndexes, err := nodesCoordinator.GetValidatorsIndexes(publicKeys, lastHeader.GetEpoch()) + if err != nil { + log.Error("getSignersIndices: cannot get signers indices", "error", err.Error(), "round", round) + return nil, false + } + + return signersIndexes, true +} + func indexValidatorsRating( outportHandler outport.OutportHandler, valStatProc process.ValidatorStatisticsProcessor, diff --git a/process/block/metrics_test.go b/process/block/metrics_test.go index 2457bd67ac1..eff2950f371 100644 --- a/process/block/metrics_test.go +++ b/process/block/metrics_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" - "github.com/stretchr/testify/assert" ) func TestMetrics_CalculateRoundDuration(t *testing.T) { @@ -32,8 +33,8 @@ func 
TestMetrics_IncrementMetricCountConsensusAcceptedBlocks(t *testing.T) { t.Parallel() nodesCoord := &shardingMocks.NodesCoordinatorMock{ - GetValidatorsPublicKeysCalled: func(_ []byte, _ uint64, _ uint32, _ uint32) ([]string, error) { - return nil, expectedErr + GetValidatorsPublicKeysCalled: func(_ []byte, _ uint64, _ uint32, _ uint32) (string, []string, error) { + return "", nil, expectedErr }, } statusHandler := &statusHandlerMock.AppStatusHandlerStub{ @@ -54,9 +55,10 @@ func TestMetrics_IncrementMetricCountConsensusAcceptedBlocks(t *testing.T) { GetOwnPublicKeyCalled: func() []byte { return []byte(mainKey) }, - GetValidatorsPublicKeysCalled: func(_ []byte, _ uint64, _ uint32, _ uint32) ([]string, error) { - return []string{ - "some leader", + GetValidatorsPublicKeysCalled: func(_ []byte, _ uint64, _ uint32, _ uint32) (string, []string, error) { + leader := "some leader" + return leader, []string{ + leader, mainKey, managedKeyInConsensus, "some other key", diff --git a/process/block/poolsCleaner/miniBlocksPoolsCleaner_test.go b/process/block/poolsCleaner/miniBlocksPoolsCleaner_test.go index b590009bdf7..ba16c9dadbb 100644 --- a/process/block/poolsCleaner/miniBlocksPoolsCleaner_test.go +++ b/process/block/poolsCleaner/miniBlocksPoolsCleaner_test.go @@ -6,9 +6,11 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) @@ -19,7 +21,7 @@ func createMockArgMiniBlocksPoolsCleaner() ArgMiniBlocksPoolsCleaner { ShardCoordinator: &mock.CoordinatorStub{}, MaxRoundsToKeepUnprocessedData: 1, }, - MiniblocksPool: testscommon.NewCacherStub(), + MiniblocksPool: cache.NewCacherStub(), } } @@ -103,7 +105,7 @@ func TestCleanMiniblocksPoolsIfNeeded_MiniblockNotInPoolShouldBeRemovedFromMap(t t.Parallel() args := createMockArgMiniBlocksPoolsCleaner() - args.MiniblocksPool = &testscommon.CacherStub{ + args.MiniblocksPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -122,7 +124,7 @@ func TestCleanMiniblocksPoolsIfNeeded_RoundDiffTooSmallMiniblockShouldRemainInMa t.Parallel() args := createMockArgMiniBlocksPoolsCleaner() - args.MiniblocksPool = &testscommon.CacherStub{ + args.MiniblocksPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -142,7 +144,7 @@ func TestCleanMiniblocksPoolsIfNeeded_MbShouldBeRemovedFromPoolAndMap(t *testing args := createMockArgMiniBlocksPoolsCleaner() called := false - args.MiniblocksPool = &testscommon.CacherStub{ + args.MiniblocksPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, diff --git a/process/block/poolsCleaner/txsPoolsCleaner_test.go b/process/block/poolsCleaner/txsPoolsCleaner_test.go index 125f44e1870..cbcab2aae85 100644 --- a/process/block/poolsCleaner/txsPoolsCleaner_test.go +++ b/process/block/poolsCleaner/txsPoolsCleaner_test.go @@ -6,14 +6,16 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" 
"github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/txcache" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" - "github.com/stretchr/testify/assert" ) func createMockArgTxsPoolsCleaner() ArgTxsPoolsCleaner { @@ -174,7 +176,7 @@ func TestReceivedBlockTx_ShouldBeAddedInMapTxsRounds(t *testing.T) { TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return testscommon.NewCacherMock() + return cache.NewCacherMock() }, } }, @@ -199,7 +201,7 @@ func TestReceivedRewardTx_ShouldBeAddedInMapTxsRounds(t *testing.T) { RewardTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return testscommon.NewCacherMock() + return cache.NewCacherMock() }, } }, @@ -223,7 +225,7 @@ func TestReceivedUnsignedTx_ShouldBeAddedInMapTxsRounds(t *testing.T) { UnsignedTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return testscommon.NewCacherMock() + return cache.NewCacherMock() }, } }, @@ -252,7 +254,7 @@ func TestCleanTxsPoolsIfNeeded_CannotFindTxInPoolShouldBeRemovedFromMap(t *testi UnsignedTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return testscommon.NewCacherMock() + return cache.NewCacherMock() }, } }, @@ -283,7 +285,7 @@ func TestCleanTxsPoolsIfNeeded_RoundDiffTooSmallShouldNotBeRemoved(t *testing.T) UnsignedTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -323,7 +325,7 @@ func TestCleanTxsPoolsIfNeeded_RoundDiffTooBigShouldBeRemoved(t *testing.T) { UnsignedTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, diff --git a/process/block/preprocess/rewardTxPreProcessor_test.go b/process/block/preprocess/rewardTxPreProcessor_test.go index ad0d0952569..836a85d8652 100644 --- a/process/block/preprocess/rewardTxPreProcessor_test.go +++ b/process/block/preprocess/rewardTxPreProcessor_test.go @@ -9,17 +9,19 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/rewardTx" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" 
"github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/common" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" ) const testTxHash = "tx1_hash" @@ -904,7 +906,7 @@ func TestRewardTxPreprocessor_RestoreBlockDataIntoPools(t *testing.T) { blockBody := &block.Body{} blockBody.MiniBlocks = append(blockBody.MiniBlocks, &mb1) - miniBlockPool := testscommon.NewCacherMock() + miniBlockPool := cache.NewCacherMock() numRestoredTxs, err := rtp.RestoreBlockDataIntoPools(blockBody, miniBlockPool) assert.Equal(t, 1, numRestoredTxs) diff --git a/process/block/preprocess/smartContractResults_test.go b/process/block/preprocess/smartContractResults_test.go index 6f56571c7d7..37a03255c66 100644 --- a/process/block/preprocess/smartContractResults_test.go +++ b/process/block/preprocess/smartContractResults_test.go @@ -13,20 +13,22 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/smartContractResult" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" commonTests "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" ) func haveTime() time.Duration { @@ -691,7 +693,7 @@ func TestScrsPreprocessor_ReceivedTransactionShouldEraseRequested(t *testing.T) shardedDataStub := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return &smartContractResult.SmartContractResult{}, true }, @@ -1430,7 +1432,7 @@ func TestScrsPreprocessor_ProcessMiniBlock(t *testing.T) { tdp.TransactionsCalled = func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) 
{ - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &smartContractResult.SmartContractResult{Nonce: 10}, true @@ -1589,7 +1591,7 @@ func TestScrsPreprocessor_RestoreBlockDataIntoPools(t *testing.T) { } body.MiniBlocks = append(body.MiniBlocks, &miniblock) - miniblockPool := testscommon.NewCacherMock() + miniblockPool := cache.NewCacherMock() scrRestored, err := scr.RestoreBlockDataIntoPools(body, miniblockPool) assert.Equal(t, scrRestored, 1) diff --git a/process/block/preprocess/transactions_test.go b/process/block/preprocess/transactions_test.go index 193e08de309..68dd2d7e709 100644 --- a/process/block/preprocess/transactions_test.go +++ b/process/block/preprocess/transactions_test.go @@ -21,6 +21,10 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing/blake2b" "github.com/multiversx/mx-chain-core-go/hashing/sha256" "github.com/multiversx/mx-chain-core-go/marshal" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" @@ -29,6 +33,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/txcache" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" commonMocks "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ -39,9 +44,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/vm" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const MaxGasLimitPerBlock = uint64(100000) @@ -78,7 +80,7 @@ func feeHandlerMock() *economicsmocks.EconomicsHandlerMock { func shardedDataCacherNotifier() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &smartContractResult.SmartContractResult{Nonce: 10}, true @@ -123,7 +125,7 @@ func initDataPool() *dataRetrieverMock.PoolsHolderStub { RewardTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &rewardTx.RewardTx{Value: big.NewInt(100)}, true @@ -155,7 +157,7 @@ func initDataPool() *dataRetrieverMock.PoolsHolderStub { } }, MetaBlocksCalled: func() 
storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &transaction.Transaction{Nonce: 10}, true @@ -178,7 +180,7 @@ func initDataPool() *dataRetrieverMock.PoolsHolderStub { } }, MiniBlocksCalled: func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -522,7 +524,7 @@ func TestTransactionPreprocessor_ReceivedTransactionShouldEraseRequested(t *test shardedDataStub := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return &transaction.Transaction{}, true }, @@ -1248,7 +1250,7 @@ func TestTransactionsPreprocessor_ProcessMiniBlockShouldWork(t *testing.T) { TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx_hash1")) { return &transaction.Transaction{Nonce: 10}, true @@ -1334,7 +1336,7 @@ func TestTransactionsPreprocessor_ProcessMiniBlockShouldErrMaxGasLimitUsedForDes TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx_hash1")) { return &transaction.Transaction{}, true @@ -2046,7 +2048,7 @@ func TestTransactions_RestoreBlockDataIntoPools(t *testing.T) { args.Store = genericMocks.NewChainStorerMock(0) txs, _ := NewTransactionPreprocessor(args) - mbPool := testscommon.NewCacherMock() + mbPool := cache.NewCacherMock() body, allTxs := createMockBlockBody() storer, _ := args.Store.GetStorer(dataRetriever.TransactionUnit) diff --git a/process/block/preprocess/validatorInfoPreProcessor_test.go b/process/block/preprocess/validatorInfoPreProcessor_test.go index 059c6c3d0b1..59cf03baa6c 100644 --- a/process/block/preprocess/validatorInfoPreProcessor_test.go +++ b/process/block/preprocess/validatorInfoPreProcessor_test.go @@ -8,17 +8,19 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/rewardTx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" 
"github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestNewValidatorInfoPreprocessor_NilHasherShouldErr(t *testing.T) { @@ -289,7 +291,7 @@ func TestNewValidatorInfoPreprocessor_RestorePeerBlockIntoPools(t *testing.T) { blockBody := &block.Body{} blockBody.MiniBlocks = append(blockBody.MiniBlocks, &mb1) - miniBlockPool := testscommon.NewCacherMock() + miniBlockPool := cache.NewCacherMock() marshalizedMb, _ := marshalizer.Marshal(mb1) mbHash := hasher.Compute(string(marshalizedMb)) @@ -334,7 +336,7 @@ func TestNewValidatorInfoPreprocessor_RestoreOtherBlockTypeIntoPoolsShouldNotRes blockBody := &block.Body{} blockBody.MiniBlocks = append(blockBody.MiniBlocks, &mb1) - miniBlockPool := testscommon.NewCacherMock() + miniBlockPool := cache.NewCacherMock() marshalizedMb, _ := marshalizer.Marshal(mb1) mbHash := hasher.Compute(string(marshalizedMb)) @@ -382,7 +384,7 @@ func TestNewValidatorInfoPreprocessor_RemovePeerBlockFromPool(t *testing.T) { blockBody := &block.Body{} blockBody.MiniBlocks = append(blockBody.MiniBlocks, &mb1) - miniBlockPool := testscommon.NewCacherMock() + miniBlockPool := cache.NewCacherMock() miniBlockPool.Put(mbHash, marshalizedMb, len(marshalizedMb)) foundMb, ok := miniBlockPool.Get(mbHash) @@ -427,7 +429,7 @@ func TestNewValidatorInfoPreprocessor_RemoveOtherBlockTypeFromPoolShouldNotRemov blockBody := &block.Body{} blockBody.MiniBlocks = append(blockBody.MiniBlocks, &mb1) - miniBlockPool := testscommon.NewCacherMock() + miniBlockPool := cache.NewCacherMock() miniBlockPool.Put(mbHash, marshalizedMb, len(marshalizedMb)) foundMb, ok := miniBlockPool.Get(mbHash) diff --git a/process/block/shardblock.go b/process/block/shardblock.go index 2953b0fc7de..f5ffdc411db 100644 --- a/process/block/shardblock.go +++ b/process/block/shardblock.go @@ -2,6 +2,8 @@ package block import ( "bytes" + "encoding/hex" + "errors" "fmt" "math/big" "time" @@ -11,6 +13,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/headerVersionData" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/holders" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -20,7 +24,6 @@ import ( "github.com/multiversx/mx-chain-go/process/block/helpers" "github.com/multiversx/mx-chain-go/process/block/processedMb" "github.com/multiversx/mx-chain-go/state" - logger "github.com/multiversx/mx-chain-logger-go" ) var _ process.BlockProcessor = (*shardProcessor)(nil) @@ -48,7 +51,6 @@ type createAndProcessMiniBlocksDestMeInfo struct { type shardProcessor struct { *baseProcessor metaBlockFinality uint32 - chRcvAllMetaHdrs chan bool } // NewShardProcessor creates a new shardProcessor object @@ -108,6 +110,7 @@ func NewShardProcessor(arguments ArgShardProcessor) (*shardProcessor, error) { enableEpochsHandler: arguments.CoreComponents.EnableEpochsHandler(), roundNotifier: arguments.CoreComponents.RoundNotifier(), enableRoundsHandler: arguments.CoreComponents.EnableRoundsHandler(), + epochChangeGracePeriodHandler: arguments.CoreComponents.EpochChangeGracePeriodHandler(), vmContainerFactory: arguments.VMContainersFactory, 
vmContainer: arguments.VmContainer, processDataTriesOnCommitEpoch: arguments.Config.Debug.EpochStart.ProcessDataTrieOnCommitEpoch, @@ -124,6 +127,7 @@ func NewShardProcessor(arguments ArgShardProcessor) (*shardProcessor, error) { managedPeersHolder: arguments.ManagedPeersHolder, sentSignaturesTracker: arguments.SentSignaturesTracker, extraDelayRequestBlockInfo: time.Duration(arguments.Config.EpochStartConfig.ExtraDelayForRequestBlockInfoInMilliseconds) * time.Millisecond, + proofsPool: arguments.DataComponents.Datapool().Proofs(), } sp := shardProcessor{ @@ -144,13 +148,15 @@ func NewShardProcessor(arguments ArgShardProcessor) (*shardProcessor, error) { sp.requestBlockBodyHandler = &sp sp.blockProcessor = &sp - sp.chRcvAllMetaHdrs = make(chan bool) + sp.chRcvAllHdrs = make(chan bool) sp.hdrsForCurrBlock = newHdrForBlock() headersPool := sp.dataPool.Headers() headersPool.RegisterHandler(sp.receivedMetaBlock) + sp.proofsPool.RegisterHandler(sp.checkReceivedProofIfAttestingIsNeeded) + sp.metaBlockFinality = process.BlockFinality return &sp, nil @@ -171,7 +177,7 @@ func (sp *shardProcessor) ProcessBlock( err := sp.checkBlockValidity(headerHandler, bodyHandler) if err != nil { - if err == process.ErrBlockHashDoesNotMatch { + if errors.Is(err, process.ErrBlockHashDoesNotMatch) { log.Debug("requested missing shard header", "hash", headerHandler.GetPrevHash(), "for shard", headerHandler.GetShardID(), @@ -234,7 +240,7 @@ func (sp *shardProcessor) ProcessBlock( } sp.txCoordinator.RequestBlockTransactions(body) - requestedMetaHdrs, requestedFinalityAttestingMetaHdrs := sp.requestMetaHeaders(header) + requestedMetaHdrs, requestedFinalityAttestingMetaHdrs, requestedProofs := sp.requestMetaHeaders(header) if haveTime() < 0 { return process.ErrTimeIsOut @@ -245,7 +251,7 @@ func (sp *shardProcessor) ProcessBlock( return err } - haveMissingMetaHeaders := requestedMetaHdrs > 0 || requestedFinalityAttestingMetaHdrs > 0 + haveMissingMetaHeaders := requestedMetaHdrs > 0 || requestedFinalityAttestingMetaHdrs > 0 || requestedProofs > 0 if haveMissingMetaHeaders { if requestedMetaHdrs > 0 { log.Debug("requested missing meta headers", @@ -257,11 +263,17 @@ func (sp *shardProcessor) ProcessBlock( "num finality meta headers", requestedFinalityAttestingMetaHdrs, ) } + if requestedProofs > 0 { + log.Debug("requested missing meta header proofs", + "num proofs", requestedProofs, + ) + } err = sp.waitForMetaHdrHashes(haveTime()) sp.hdrsForCurrBlock.mutHdrsForBlock.RLock() missingMetaHdrs := sp.hdrsForCurrBlock.missingHdrs + missingProofs := sp.hdrsForCurrBlock.missingProofs sp.hdrsForCurrBlock.mutHdrsForBlock.RUnlock() sp.hdrsForCurrBlock.resetMissingHdrs() @@ -272,6 +284,12 @@ func (sp *shardProcessor) ProcessBlock( ) } + if requestedProofs > 0 { + log.Debug("received missing meta header proofs", + "num proofs", requestedProofs-missingProofs, + ) + } + if err != nil { return err } @@ -374,8 +392,12 @@ func (sp *shardProcessor) requestEpochStartInfo(header data.ShardHeaderHandler, return nil } + // force header cleanup from pool so that the receiving of the epoch start meta block will reach the trigger + sp.dataPool.Headers().RemoveHeaderByHash(header.GetEpochStartMetaHash()) go sp.requestHandler.RequestMetaHeader(header.GetEpochStartMetaHash()) + sp.requestEpochStartProofIfNeeded(header.GetEpochStartMetaHash(), header.GetEpoch()) + headersPool := sp.dataPool.Headers() for { time.Sleep(timeBetweenCheckForEpochStart) @@ -390,13 +412,22 @@ func (sp *shardProcessor) requestEpochStartInfo(header data.ShardHeaderHandler, 
epochStartMetaHdr, err := headersPool.GetHeaderByHash(header.GetEpochStartMetaHash()) if err != nil { go sp.requestHandler.RequestMetaHeader(header.GetEpochStartMetaHash()) + sp.requestEpochStartProofIfNeeded(header.GetEpochStartMetaHash(), header.GetEpoch()) continue } - _, _, err = headersPool.GetHeadersByNonceAndShardId(epochStartMetaHdr.GetNonce()+1, core.MetachainShardId) - if err != nil { - go sp.requestHandler.RequestMetaHeaderByNonce(epochStartMetaHdr.GetNonce() + 1) - continue + shouldConsiderProofsForNotarization := sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epochStartMetaHdr.GetEpoch()) + if !shouldConsiderProofsForNotarization { + _, _, err = headersPool.GetHeadersByNonceAndShardId(epochStartMetaHdr.GetNonce()+1, core.MetachainShardId) + if err != nil { + go sp.requestHandler.RequestMetaHeaderByNonce(epochStartMetaHdr.GetNonce() + 1) + continue + } + } else { + hasProof := sp.requestEpochStartProofIfNeeded(header.GetEpochStartMetaHash(), header.GetEpoch()) + if !hasProof { + continue + } } return nil @@ -405,6 +436,21 @@ func (sp *shardProcessor) requestEpochStartInfo(header data.ShardHeaderHandler, return process.ErrTimeIsOut } +func (sp *shardProcessor) requestEpochStartProofIfNeeded(hash []byte, epoch uint32) bool { + if !sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + return true // no proof needed + } + + hasProof := sp.proofsPool.HasProof(core.MetachainShardId, hash) + if hasProof { + return true + } + + go sp.requestHandler.RequestEquivalentProofByHash(core.MetachainShardId, hash) + + return false +} + // RevertStateToBlock recreates the state tries to the root hashes indicated by the provided root hash and header func (sp *shardProcessor) RevertStateToBlock(header data.HeaderHandler, rootHash []byte) error { rootHashHolder := holders.NewDefaultRootHashesHolder(rootHash) @@ -469,10 +515,20 @@ func (sp *shardProcessor) checkEpochCorrectness( process.ErrEpochDoesNotMatch, header.GetEpoch(), sp.epochStartTrigger.MetaEpoch()) } + epochChangeConfirmed := sp.epochStartTrigger.EpochStartRound() < sp.epochStartTrigger.EpochFinalityAttestingRound() + if sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + epochChangeConfirmed = sp.epochStartTrigger.EpochStartRound() <= sp.epochStartTrigger.EpochFinalityAttestingRound() + } + + gracePeriod, err := sp.epochChangeGracePeriodHandler.GetGracePeriodForEpoch(header.GetEpoch()) + if err != nil { + return fmt.Errorf("%w could not get grace period for epoch %d", err, header.GetEpoch()) + } + isOldEpochAndShouldBeNew := sp.epochStartTrigger.IsEpochStart() && - header.GetRound() > sp.epochStartTrigger.EpochFinalityAttestingRound()+process.EpochChangeGracePeriod && + header.GetRound() > sp.epochStartTrigger.EpochFinalityAttestingRound()+uint64(gracePeriod) && header.GetEpoch() < sp.epochStartTrigger.MetaEpoch() && - sp.epochStartTrigger.EpochStartRound() < sp.epochStartTrigger.EpochFinalityAttestingRound() + epochChangeConfirmed if isOldEpochAndShouldBeNew { return fmt.Errorf("%w proposed header with epoch %d should be in epoch %d", process.ErrEpochDoesNotMatch, header.GetEpoch(), sp.epochStartTrigger.MetaEpoch()) @@ -527,7 +583,10 @@ func (sp *shardProcessor) checkMetaHeadersValidityAndFinality() error { } log.Trace("checkMetaHeadersValidityAndFinality", "lastCrossNotarizedHeader nonce", lastCrossNotarizedHeader.GetNonce()) - usedMetaHdrs := sp.sortHeadersForCurrentBlockByNonce(true) + usedMetaHdrs, err := sp.sortHeadersForCurrentBlockByNonce(true) + if 
err != nil { + return err + } if len(usedMetaHdrs[core.MetachainShardId]) == 0 { return nil } @@ -550,13 +609,33 @@ func (sp *shardProcessor) checkMetaHeadersValidityAndFinality() error { return nil } +func (sp *shardProcessor) checkHeaderHasProof(header data.HeaderHandler) error { + hash, errHash := sp.getHeaderHash(header) + if errHash != nil { + return errHash + } + + if !sp.proofsPool.HasProof(header.GetShardID(), hash) { + return fmt.Errorf("%w, missing proof for header %s", process.ErrHeaderNotFinal, hex.EncodeToString(hash)) + } + + return nil +} + // check if shard headers are final by checking if newer headers were constructed upon them func (sp *shardProcessor) checkMetaHdrFinality(header data.HeaderHandler) error { if check.IfNil(header) { return process.ErrNilBlockHeader } - finalityAttestingMetaHdrs := sp.sortHeadersForCurrentBlockByNonce(false) + if sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + return sp.checkHeaderHasProof(header) + } + + finalityAttestingMetaHdrs, err := sp.sortHeadersForCurrentBlockByNonce(false) + if err != nil { + return err + } lastVerifiedHdr := header // verify if there are "K" block after current to make this one final @@ -575,6 +654,10 @@ func (sp *shardProcessor) checkMetaHdrFinality(header data.HeaderHandler) error continue } + if sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, metaHdr.GetEpoch()) { + return sp.checkHeaderHasProof(metaHdr) + } + lastVerifiedHdr = metaHdr nextBlocksVerified += 1 } @@ -632,7 +715,14 @@ func (sp *shardProcessor) indexBlockIfNeeded( log.Debug("indexed block", "hash", headerHash, "nonce", header.GetNonce(), "round", header.GetRound()) shardID := sp.shardCoordinator.SelfId() - indexRoundInfo(sp.outportHandler, sp.nodesCoordinator, shardID, header, lastBlockHeader, argSaveBlock.SignersIndexes) + indexRoundInfo( + sp.outportHandler, + sp.nodesCoordinator, + shardID, header, + lastBlockHeader, + argSaveBlock.SignersIndexes, + sp.enableEpochsHandler, + ) } // RestoreBlockIntoPools restores the TxBlock and MetaBlock into associated pools @@ -1026,7 +1116,12 @@ func (sp *shardProcessor) CommitBlock( sp.lastRestartNonce = header.GetNonce() } - sp.updateState(selfNotarizedHeaders, header) + finalHeaderHash := headerHash + if !common.IsFlagEnabledAfterEpochsStartBlock(header, sp.enableEpochsHandler, common.AndromedaFlag) { + finalHeaderHash = currentHeaderHash + } + + sp.updateState(selfNotarizedHeaders, header, finalHeaderHash) highestFinalBlockNonce := sp.forkDetector.GetHighestFinalBlockNonce() log.Debug("highest final shard block", @@ -1163,7 +1258,7 @@ func (sp *shardProcessor) displayPoolsInfo() { sp.displayMiniBlocksPool() } -func (sp *shardProcessor) updateState(headers []data.HeaderHandler, currentHeader data.ShardHeaderHandler) { +func (sp *shardProcessor) updateState(headers []data.HeaderHandler, currentHeader data.ShardHeaderHandler, currentHeaderHash []byte) { sp.snapShotEpochStartFromMeta(currentHeader) for _, header := range headers { @@ -1234,15 +1329,36 @@ func (sp *shardProcessor) updateState(headers []data.HeaderHandler, currentHeade sp.accountsDB[state.UserAccountsState], ) + if sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + continue + } + sp.setFinalizedHeaderHashInIndexer(header.GetPrevHash()) - finalRootHash := scheduledHeaderRootHash - if len(finalRootHash) == 0 { - finalRootHash = header.GetRootHash() - } + sp.setFinalBlockInfo(header, headerHash, scheduledHeaderRootHash) + } - 
sp.blockChain.SetFinalBlockInfo(header.GetNonce(), headerHash, finalRootHash) + if !sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, currentHeader.GetEpoch()) { + return } + + sp.setFinalizedHeaderHashInIndexer(currentHeaderHash) + + scheduledHeaderRootHash, _ := sp.scheduledTxsExecutionHandler.GetScheduledRootHashForHeader(currentHeaderHash) + sp.setFinalBlockInfo(currentHeader, currentHeaderHash, scheduledHeaderRootHash) +} + +func (sp *shardProcessor) setFinalBlockInfo( + header data.HeaderHandler, + headerHash []byte, + scheduledHeaderRootHash []byte, +) { + finalRootHash := scheduledHeaderRootHash + if len(finalRootHash) == 0 { + finalRootHash = header.GetRootHash() + } + + sp.blockChain.SetFinalBlockInfo(header.GetNonce(), headerHash, finalRootHash) } func (sp *shardProcessor) snapShotEpochStartFromMeta(header data.ShardHeaderHandler) { @@ -1324,7 +1440,12 @@ func (sp *shardProcessor) checkEpochCorrectnessCrossChain() error { shouldRevertChain := false nonce := currentHeader.GetNonce() - shouldEnterNewEpochRound := sp.epochStartTrigger.EpochFinalityAttestingRound() + process.EpochChangeGracePeriod + gracePeriodForEpoch, err := sp.epochChangeGracePeriodHandler.GetGracePeriodForEpoch(sp.epochStartTrigger.MetaEpoch()) + if err != nil { + log.Debug("checkEpochCorrectnessCrossChain.GetGracePeriodForEpoch", "error", err.Error()) + return err + } + shouldEnterNewEpochRound := sp.epochStartTrigger.EpochFinalityAttestingRound() + uint64(gracePeriodForEpoch) for round := currentHeader.GetRound(); round > shouldEnterNewEpochRound && currentHeader.GetEpoch() < sp.epochStartTrigger.MetaEpoch(); round = currentHeader.GetRound() { if round <= lastFinalizedRound { @@ -1411,6 +1532,7 @@ func (sp *shardProcessor) CreateNewHeader(round uint64, nonce uint64) (data.Head } sp.roundNotifier.CheckRound(header) + sp.epochNotifier.CheckEpoch(header) err = shardHeader.SetNonce(nonce) if err != nil { @@ -1707,6 +1829,8 @@ func (sp *shardProcessor) receivedMetaBlock(headerHandler data.HeaderHandler, me hdrInfoForHash := sp.hdrsForCurrBlock.hdrHashAndInfo[string(metaBlockHash)] headerInfoIsNotNil := hdrInfoForHash != nil headerIsMissing := headerInfoIsNotNil && check.IfNil(hdrInfoForHash.hdr) + hasProof := headerInfoIsNotNil && hdrInfoForHash.hasProof + hasProofRequested := headerInfoIsNotNil && hdrInfoForHash.hasProofRequested if headerIsMissing { hdrInfoForHash.hdr = metaBlock sp.hdrsForCurrBlock.missingHdrs-- @@ -1714,14 +1838,15 @@ func (sp *shardProcessor) receivedMetaBlock(headerHandler data.HeaderHandler, me if metaBlock.Nonce > sp.hdrsForCurrBlock.highestHdrNonce[core.MetachainShardId] { sp.hdrsForCurrBlock.highestHdrNonce[core.MetachainShardId] = metaBlock.Nonce } + + if !hasProof && !hasProofRequested { + sp.requestProofIfNeeded(metaBlockHash, metaBlock) + } } - // attesting something if sp.hdrsForCurrBlock.missingHdrs == 0 { - sp.hdrsForCurrBlock.missingFinalityAttestingHdrs = sp.requestMissingFinalityAttestingHeaders( - core.MetachainShardId, - sp.metaBlockFinality, - ) + sp.checkFinalityRequestingMissing(metaBlock) + if sp.hdrsForCurrBlock.missingFinalityAttestingHdrs == 0 { log.Debug("received all missing finality attesting meta headers") } @@ -1729,11 +1854,12 @@ func (sp *shardProcessor) receivedMetaBlock(headerHandler data.HeaderHandler, me missingMetaHdrs := sp.hdrsForCurrBlock.missingHdrs missingFinalityAttestingMetaHdrs := sp.hdrsForCurrBlock.missingFinalityAttestingHdrs + missingProofs := sp.hdrsForCurrBlock.missingProofs sp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() - 
allMissingMetaHeadersReceived := missingMetaHdrs == 0 && missingFinalityAttestingMetaHdrs == 0 + allMissingMetaHeadersReceived := missingMetaHdrs == 0 && missingFinalityAttestingMetaHdrs == 0 && missingProofs == 0 if allMissingMetaHeadersReceived { - sp.chRcvAllMetaHdrs <- true + sp.chRcvAllHdrs <- true } } else { sp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() @@ -1742,21 +1868,34 @@ func (sp *shardProcessor) receivedMetaBlock(headerHandler data.HeaderHandler, me go sp.requestMiniBlocksIfNeeded(headerHandler) } -func (sp *shardProcessor) requestMetaHeaders(shardHeader data.ShardHeaderHandler) (uint32, uint32) { - _ = core.EmptyChannel(sp.chRcvAllMetaHdrs) +func (sp *shardProcessor) checkFinalityRequestingMissing(metaBlock *block.MetaBlock) { + shouldConsiderProofsForNotarization := sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, metaBlock.Epoch) + if !shouldConsiderProofsForNotarization { + sp.hdrsForCurrBlock.missingFinalityAttestingHdrs = sp.requestMissingFinalityAttestingHeaders( + core.MetachainShardId, + sp.metaBlockFinality, + ) + + return // no proof needed + } +} + +func (sp *shardProcessor) requestMetaHeaders(shardHeader data.ShardHeaderHandler) (uint32, uint32, uint32) { + _ = core.EmptyChannel(sp.chRcvAllHdrs) if len(shardHeader.GetMetaBlockHashes()) == 0 { - return 0, 0 + return 0, 0, 0 } return sp.computeExistingAndRequestMissingMetaHeaders(shardHeader) } -func (sp *shardProcessor) computeExistingAndRequestMissingMetaHeaders(header data.ShardHeaderHandler) (uint32, uint32) { +func (sp *shardProcessor) computeExistingAndRequestMissingMetaHeaders(header data.ShardHeaderHandler) (uint32, uint32, uint32) { sp.hdrsForCurrBlock.mutHdrsForBlock.Lock() defer sp.hdrsForCurrBlock.mutHdrsForBlock.Unlock() metaBlockHashes := header.GetMetaBlockHashes() + lastMetablockNonceWithProof := uint64(0) for i := 0; i < len(metaBlockHashes); i++ { hdr, err := process.GetMetaHeaderFromPool( metaBlockHashes[i], @@ -1765,9 +1904,12 @@ func (sp *shardProcessor) computeExistingAndRequestMissingMetaHeaders(header dat if err != nil { sp.hdrsForCurrBlock.missingHdrs++ sp.hdrsForCurrBlock.hdrHashAndInfo[string(metaBlockHashes[i])] = &hdrInfo{ - hdr: nil, - usedInBlock: true, + usedInBlock: true, + hdr: nil, + hasProof: false, + hasProofRequested: false, } + go sp.requestHandler.RequestMetaHeader(metaBlockHashes[i]) continue } @@ -1777,19 +1919,32 @@ func (sp *shardProcessor) computeExistingAndRequestMissingMetaHeaders(header dat usedInBlock: true, } - if hdr.Nonce > sp.hdrsForCurrBlock.highestHdrNonce[core.MetachainShardId] { - sp.hdrsForCurrBlock.highestHdrNonce[core.MetachainShardId] = hdr.Nonce + if hdr.GetNonce() > sp.hdrsForCurrBlock.highestHdrNonce[core.MetachainShardId] { + sp.hdrsForCurrBlock.highestHdrNonce[core.MetachainShardId] = hdr.GetNonce() + } + + shouldConsiderProofsForNotarization := sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, hdr.GetEpoch()) + if !shouldConsiderProofsForNotarization { + continue + } + + sp.requestProofIfNeeded(metaBlockHashes[i], hdr) + + sp.hdrsForCurrBlock.hdrHashAndInfo[string(metaBlockHashes[i])].hasProofRequested = true + + if hdr.GetNonce() > lastMetablockNonceWithProof { + lastMetablockNonceWithProof = hdr.GetNonce() } } - if sp.hdrsForCurrBlock.missingHdrs == 0 { + if sp.hdrsForCurrBlock.missingHdrs == 0 && lastMetablockNonceWithProof == 0 { sp.hdrsForCurrBlock.missingFinalityAttestingHdrs = sp.requestMissingFinalityAttestingHeaders( core.MetachainShardId, sp.metaBlockFinality, ) } - return sp.hdrsForCurrBlock.missingHdrs, 
sp.hdrsForCurrBlock.missingFinalityAttestingHdrs + return sp.hdrsForCurrBlock.missingHdrs, sp.hdrsForCurrBlock.missingFinalityAttestingHdrs, sp.hdrsForCurrBlock.missingProofs } func (sp *shardProcessor) verifyCrossShardMiniBlockDstMe(header data.ShardHeaderHandler) error { @@ -1910,9 +2065,21 @@ func (sp *shardProcessor) createAndProcessMiniBlocksDstMe(haveTime func() bool) break } + hasProofForHdr := sp.proofsPool.HasProof(core.MetachainShardId, orderedMetaBlocksHashes[i]) + shouldConsiderProofsForNotarization := sp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, orderedMetaBlocks[i].GetEpoch()) + if !hasProofForHdr && shouldConsiderProofsForNotarization { + log.Trace("no proof for meta header", + "hash", logger.DisplayByteSlice(orderedMetaBlocksHashes[i]), + ) + break + } + createAndProcessInfo.currMetaHdrHash = orderedMetaBlocksHashes[i] if len(createAndProcessInfo.currMetaHdr.GetMiniBlockHeadersWithDst(sp.shardCoordinator.SelfId())) == 0 { - sp.hdrsForCurrBlock.hdrHashAndInfo[string(createAndProcessInfo.currMetaHdrHash)] = &hdrInfo{hdr: createAndProcessInfo.currMetaHdr, usedInBlock: true} + sp.hdrsForCurrBlock.hdrHashAndInfo[string(createAndProcessInfo.currMetaHdrHash)] = &hdrInfo{ + hdr: createAndProcessInfo.currMetaHdr, + usedInBlock: true, + } createAndProcessInfo.numHdrsAdded++ lastMetaHdr = createAndProcessInfo.currMetaHdr continue @@ -1976,7 +2143,10 @@ func (sp *shardProcessor) createMbsAndProcessCrossShardTransactionsDstMe( createAndProcessInfo.numTxsAdded += currNumTxsAdded if !createAndProcessInfo.hdrAdded && currNumTxsAdded > 0 { - sp.hdrsForCurrBlock.hdrHashAndInfo[string(createAndProcessInfo.currMetaHdrHash)] = &hdrInfo{hdr: createAndProcessInfo.currMetaHdr, usedInBlock: true} + sp.hdrsForCurrBlock.hdrHashAndInfo[string(createAndProcessInfo.currMetaHdrHash)] = &hdrInfo{ + hdr: createAndProcessInfo.currMetaHdr, + usedInBlock: true, + } createAndProcessInfo.numHdrsAdded++ createAndProcessInfo.hdrAdded = true } @@ -2190,8 +2360,12 @@ func (sp *shardProcessor) applyBodyToHeader( } sw.Start("sortHeaderHashesForCurrentBlockByNonce") - metaBlockHashes := sp.sortHeaderHashesForCurrentBlockByNonce(true) + metaBlockHashes, err := sp.sortHeaderHashesForCurrentBlockByNonce(true) sw.Stop("sortHeaderHashesForCurrentBlockByNonce") + if err != nil { + return nil, err + } + err = shardHeader.SetMetaBlockHashes(metaBlockHashes[core.MetachainShardId]) if err != nil { return nil, err @@ -2216,7 +2390,7 @@ func (sp *shardProcessor) applyBodyToHeader( func (sp *shardProcessor) waitForMetaHdrHashes(waitTime time.Duration) error { select { - case <-sp.chRcvAllMetaHdrs: + case <-sp.chRcvAllHdrs: return nil case <-time.After(waitTime): return process.ErrTimeIsOut diff --git a/process/block/shardblockRequest_test.go b/process/block/shardblockRequest_test.go index 2440c6ecba5..3ab3a0f942f 100644 --- a/process/block/shardblockRequest_test.go +++ b/process/block/shardblockRequest_test.go @@ -116,11 +116,12 @@ func TestShardProcessor_computeExistingAndRequestMissingMetaHeaders(t *testing.T blockBeingProcessed := shard1Data.headerData[1].header shardBlockBeingProcessed := blockBeingProcessed.(*block.Header) - missingHeaders, missingFinalityAttestingHeaders := sp.ComputeExistingAndRequestMissingMetaHeaders(shardBlockBeingProcessed) + missingHeaders, missingFinalityAttestingHeaders, missingProofs := sp.ComputeExistingAndRequestMissingMetaHeaders(shardBlockBeingProcessed) time.Sleep(100 * time.Millisecond) require.Equal(t, uint32(1), missingHeaders) require.Equal(t, uint32(0), 
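receivedMetaBlock, requestMetaHeaders and waitForMetaHdrHashes now track three kinds of missing data (headers, finality-attesting headers, proofs) and signal completion on the renamed chRcvAllHdrs channel. A rough sketch of that bookkeeping under assumed names; the buffered channel and the guarded decrements are simplifications for illustration, not the exact production code.

package sketch

import (
	"errors"
	"sync"
	"time"
)

var errTimeIsOut = errors.New("time is out")

// missingTracker mirrors the bookkeeping above: three counters, one signal channel.
type missingTracker struct {
	mut              sync.Mutex
	missingHdrs      uint32
	missingAttesting uint32
	missingProofs    uint32
	chRcvAllHdrs     chan bool
}

func newMissingTracker(hdrs, attesting, proofs uint32) *missingTracker {
	return &missingTracker{
		missingHdrs:      hdrs,
		missingAttesting: attesting,
		missingProofs:    proofs,
		chRcvAllHdrs:     make(chan bool, 1), // buffered so a late signal never blocks
	}
}

func (mt *missingTracker) headerReceived()    { mt.dec(&mt.missingHdrs) }
func (mt *missingTracker) attestingReceived() { mt.dec(&mt.missingAttesting) }
func (mt *missingTracker) proofReceived()     { mt.dec(&mt.missingProofs) }

// dec decrements one counter and signals once everything is accounted for.
func (mt *missingTracker) dec(counter *uint32) {
	mt.mut.Lock()
	if *counter > 0 {
		*counter--
	}
	allReceived := mt.missingHdrs == 0 && mt.missingAttesting == 0 && mt.missingProofs == 0
	mt.mut.Unlock()

	if allReceived {
		select {
		case mt.chRcvAllHdrs <- true:
		default:
		}
	}
}

// waitAll blocks until everything was received or the allotted time expires.
func (mt *missingTracker) waitAll(waitTime time.Duration) error {
	select {
	case <-mt.chRcvAllHdrs:
		return nil
	case <-time.After(waitTime):
		return errTimeIsOut
	}
}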
missingFinalityAttestingHeaders) + require.Equal(t, uint32(0), missingProofs) require.Equal(t, uint32(1), numCalls.Load()) }) t.Run("multiple referenced metaBlocks missing will be requested", func(t *testing.T) { @@ -152,11 +153,12 @@ func TestShardProcessor_computeExistingAndRequestMissingMetaHeaders(t *testing.T blockBeingProcessed := shard1Data.headerData[1].header shardBlockBeingProcessed := blockBeingProcessed.(*block.Header) - missingHeaders, missingFinalityAttestingHeaders := sp.ComputeExistingAndRequestMissingMetaHeaders(shardBlockBeingProcessed) + missingHeaders, missingFinalityAttestingHeaders, missingProofs := sp.ComputeExistingAndRequestMissingMetaHeaders(shardBlockBeingProcessed) time.Sleep(100 * time.Millisecond) require.Equal(t, uint32(2), missingHeaders) require.Equal(t, uint32(0), missingFinalityAttestingHeaders) + require.Equal(t, uint32(0), missingProofs) require.Equal(t, uint32(2), numCalls.Load()) }) t.Run("all referenced metaBlocks existing with missing attestation, will request the attestation metaBlock", func(t *testing.T) { @@ -191,11 +193,12 @@ func TestShardProcessor_computeExistingAndRequestMissingMetaHeaders(t *testing.T blockBeingProcessed := shard1Data.headerData[1].header shardBlockBeingProcessed := blockBeingProcessed.(*block.Header) - missingHeaders, missingFinalityAttestingHeaders := sp.ComputeExistingAndRequestMissingMetaHeaders(shardBlockBeingProcessed) + missingHeaders, missingFinalityAttestingHeaders, missingProofs := sp.ComputeExistingAndRequestMissingMetaHeaders(shardBlockBeingProcessed) time.Sleep(100 * time.Millisecond) require.Equal(t, uint32(0), missingHeaders) require.Equal(t, uint32(1), missingFinalityAttestingHeaders) + require.Equal(t, uint32(0), missingProofs) require.Equal(t, uint32(0), numCallsMissing.Load()) require.Equal(t, uint32(1), numCallsAttestation.Load()) }) @@ -234,11 +237,12 @@ func TestShardProcessor_computeExistingAndRequestMissingMetaHeaders(t *testing.T blockBeingProcessed := shard1Data.headerData[1].header shardBlockBeingProcessed := blockBeingProcessed.(*block.Header) - missingHeaders, missingFinalityAttestingHeaders := sp.ComputeExistingAndRequestMissingMetaHeaders(shardBlockBeingProcessed) + missingHeaders, missingFinalityAttestingHeaders, missingProofs := sp.ComputeExistingAndRequestMissingMetaHeaders(shardBlockBeingProcessed) time.Sleep(100 * time.Millisecond) require.Equal(t, uint32(0), missingHeaders) require.Equal(t, uint32(0), missingFinalityAttestingHeaders) + require.Equal(t, uint32(0), missingProofs) require.Equal(t, uint32(0), numCallsMissing.Load()) require.Equal(t, uint32(0), numCallsAttestation.Load()) }) diff --git a/process/block/shardblock_test.go b/process/block/shardblock_test.go index f6caf783286..c2855799452 100644 --- a/process/block/shardblock_test.go +++ b/process/block/shardblock_test.go @@ -174,6 +174,9 @@ func TestNewShardProcessor(t *testing.T) { HeadersCalled: func() dataRetriever.HeadersPool { return nil }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } return CreateMockArgumentsMultiShard(coreComponents, &dataCompCopy, bootstrapComponents, statusComponents) }, @@ -1433,6 +1436,9 @@ func TestShardProcessor_RequestEpochStartInfo(t *testing.T) { TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{} }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } args := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, 
statusComponents) @@ -1485,6 +1491,9 @@ func TestShardProcessor_RequestEpochStartInfo(t *testing.T) { TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{} }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } args := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -1543,6 +1552,9 @@ func TestShardProcessor_RequestEpochStartInfo(t *testing.T) { TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return &testscommon.ShardedDataStub{} }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } args := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) @@ -4516,7 +4528,7 @@ func TestShardProcessor_updateStateStorage(t *testing.T) { hdr1 := &block.Header{Nonce: 0, Round: 0} hdr2 := &block.Header{Nonce: 1, Round: 1} finalHeaders = append(finalHeaders, hdr1, hdr2) - sp.UpdateStateStorage(finalHeaders, &block.Header{}) + sp.UpdateStateStorage(finalHeaders, &block.Header{}, []byte("hash")) assert.True(t, pruneTrieWasCalled) assert.True(t, cancelPruneWasCalled) @@ -4591,7 +4603,9 @@ func TestShardProcessor_checkEpochCorrectnessCrossChainInCorrectEpochStorageErro }, } - header := &block.Header{Epoch: epochStartTrigger.Epoch() - 1, Round: epochStartTrigger.EpochFinalityAttestingRound() + process.EpochChangeGracePeriod + 1} + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + gracePeriod, _ := coreComponents.EpochChangeGracePeriodHandlerField.GetGracePeriodForEpoch(epochStartTrigger.Epoch()) + header := &block.Header{Epoch: epochStartTrigger.Epoch() - 1, Round: epochStartTrigger.EpochFinalityAttestingRound() + uint64(gracePeriod) + 1} blockChain := &testscommon.ChainHandlerStub{ GetCurrentBlockHeaderCalled: func() data.HeaderHandler { return header @@ -4600,7 +4614,6 @@ func TestShardProcessor_checkEpochCorrectnessCrossChainInCorrectEpochStorageErro return &block.Header{Nonce: 0} }, } - coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() dataComponents.BlockChain = blockChain arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) arguments.EpochStartTrigger = epochStartTrigger @@ -4639,11 +4652,15 @@ func TestShardProcessor_checkEpochCorrectnessCrossChainInCorrectEpochRollback1Bl forkDetector := &mock.ForkDetectorMock{SetRollBackNonceCalled: func(nonce uint64) { nonceCalled = nonce }} + + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + dataComponents.Storage = store prevHash := []byte("prevHash") + gracePeriod, _ := coreComponents.EpochChangeGracePeriodHandlerField.GetGracePeriodForEpoch(epochStartTrigger.Epoch()) currHeader := &block.Header{ Nonce: 10, Epoch: epochStartTrigger.Epoch() - 1, - Round: epochStartTrigger.EpochFinalityAttestingRound() + process.EpochChangeGracePeriod + 1, + Round: epochStartTrigger.EpochFinalityAttestingRound() + uint64(gracePeriod) + 1, PrevHash: prevHash} blockChain := &testscommon.ChainHandlerStub{ @@ -4654,9 +4671,6 @@ func TestShardProcessor_checkEpochCorrectnessCrossChainInCorrectEpochRollback1Bl return &block.Header{Nonce: 0} }, } - - coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - dataComponents.Storage = store dataComponents.BlockChain = blockChain arguments 
:= CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) arguments.EpochStartTrigger = epochStartTrigger @@ -4667,7 +4681,7 @@ func TestShardProcessor_checkEpochCorrectnessCrossChainInCorrectEpochRollback1Bl prevHeader := &block.Header{ Nonce: 8, Epoch: epochStartTrigger.Epoch() - 1, - Round: epochStartTrigger.EpochFinalityAttestingRound() + process.EpochChangeGracePeriod, + Round: epochStartTrigger.EpochFinalityAttestingRound() + uint64(gracePeriod), } prevHeaderData, _ := coreComponents.InternalMarshalizer().Marshal(prevHeader) @@ -4702,11 +4716,15 @@ func TestShardProcessor_checkEpochCorrectnessCrossChainInCorrectEpochRollback2Bl forkDetector := &mock.ForkDetectorMock{SetRollBackNonceCalled: func(nonce uint64) { nonceCalled = nonce }} + + coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() + dataComponents.Storage = store + gracePeriod, _ := coreComponents.EpochChangeGracePeriodHandlerField.GetGracePeriodForEpoch(epochStartTrigger.Epoch()) prevHash := []byte("prevHash") header := &block.Header{ Nonce: 10, Epoch: epochStartTrigger.Epoch() - 1, - Round: epochStartTrigger.EpochFinalityAttestingRound() + process.EpochChangeGracePeriod + 2, + Round: epochStartTrigger.EpochFinalityAttestingRound() + uint64(gracePeriod) + 2, PrevHash: prevHash} blockChain := &testscommon.ChainHandlerStub{ @@ -4717,9 +4735,6 @@ func TestShardProcessor_checkEpochCorrectnessCrossChainInCorrectEpochRollback2Bl return &block.Header{Nonce: 0} }, } - - coreComponents, dataComponents, bootstrapComponents, statusComponents := createComponentHolderMocks() - dataComponents.Storage = store dataComponents.BlockChain = blockChain arguments := CreateMockArguments(coreComponents, dataComponents, bootstrapComponents, statusComponents) arguments.EpochStartTrigger = epochStartTrigger @@ -4731,7 +4746,7 @@ func TestShardProcessor_checkEpochCorrectnessCrossChainInCorrectEpochRollback2Bl prevHeader := &block.Header{ Nonce: 8, Epoch: epochStartTrigger.Epoch() - 1, - Round: epochStartTrigger.EpochFinalityAttestingRound() + process.EpochChangeGracePeriod + 1, + Round: epochStartTrigger.EpochFinalityAttestingRound() + uint64(gracePeriod) + 1, PrevHash: prevPrevHash, } prevHeaderData, _ := coreComponents.InternalMarshalizer().Marshal(prevHeader) @@ -4745,7 +4760,7 @@ func TestShardProcessor_checkEpochCorrectnessCrossChainInCorrectEpochRollback2Bl prevPrevHeader := &block.Header{ Nonce: 7, Epoch: epochStartTrigger.Epoch() - 1, - Round: epochStartTrigger.EpochFinalityAttestingRound() + process.EpochChangeGracePeriod, + Round: epochStartTrigger.EpochFinalityAttestingRound() + uint64(gracePeriod), PrevHash: prevPrevHash, } prevPrevHeaderData, _ := coreComponents.InternalMarshalizer().Marshal(prevPrevHeader) @@ -4929,6 +4944,9 @@ func TestShardProcessor_CheckEpochCorrectnessShouldRemoveAndRequestStartOfEpochM }, } }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } ch := make(chan struct{}) diff --git a/process/common.go b/process/common.go index e8c9c7504ff..198645ea6ab 100644 --- a/process/common.go +++ b/process/common.go @@ -18,10 +18,11 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" - "github.com/multiversx/mx-chain-go/dataRetriever" - "github.com/multiversx/mx-chain-go/state" logger 
"github.com/multiversx/mx-chain-logger-go" vmcommon "github.com/multiversx/mx-chain-vm-common-go" + + "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/state" ) var log = logger.GetOrCreate("process") @@ -345,7 +346,7 @@ func GetShardHeaderFromStorageWithNonce( storageService, uint64Converter, marshalizer, - dataRetriever.ShardHdrNonceHashDataUnit+dataRetriever.UnitType(shardId)) + dataRetriever.GetHdrNonceHashDataUnit(shardId)) if err != nil { return nil, nil, err } @@ -774,6 +775,21 @@ func GetSortedStorageUpdates(account *vmcommon.OutputAccount) []*vmcommon.Storag return storageUpdates } +// GetHeader tries to get the header from pool first and if not found, searches for it through storer +func GetHeader( + headerHash []byte, + headersPool dataRetriever.HeadersPool, + headersStorer dataRetriever.StorageService, + marshaller marshal.Marshalizer, + shardID uint32, +) (data.HeaderHandler, error) { + if shardID == core.MetachainShardId { + return GetMetaHeader(headerHash, headersPool, marshaller, headersStorer) + } + + return GetShardHeader(headerHash, headersPool, marshaller, headersStorer) +} + // UnmarshalHeader unmarshalls a block header func UnmarshalHeader(shardId uint32, marshalizer marshal.Marshalizer, headerBuffer []byte) (data.HeaderHandler, error) { if shardId == core.MetachainShardId { diff --git a/process/common_test.go b/process/common_test.go index a79e2fd5c32..b6e308ec3ab 100644 --- a/process/common_test.go +++ b/process/common_test.go @@ -12,14 +12,16 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-core-go/data/typeConverters" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestGetShardHeaderShouldErrNilCacher(t *testing.T) { @@ -1800,7 +1802,7 @@ func TestGetTransactionHandlerShouldGetTransactionFromPool(t *testing.T) { storageService := &storageStubs.ChainStorerStub{} shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return txFromPool, true }, @@ -1843,7 +1845,7 @@ func TestGetTransactionHandlerShouldGetTransactionFromStorage(t *testing.T) { } shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -1871,7 +1873,7 @@ func TestGetTransactionHandlerFromPool_Errors(t *testing.T) { shardedDataCacherNotifier := testscommon.NewShardedDataStub() 
shardedDataCacherNotifier.ShardDataStoreCalled = func(cacheID string) storage.Cacher { - return testscommon.NewCacherMock() + return cache.NewCacherMock() } t.Run("nil sharded cache", func(t *testing.T) { @@ -1922,7 +1924,7 @@ func TestGetTransactionHandlerFromPoolShouldErrTxNotFound(t *testing.T) { shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -1948,7 +1950,7 @@ func TestGetTransactionHandlerFromPoolShouldErrInvalidTxInPool(t *testing.T) { shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return nil, true }, @@ -1975,7 +1977,7 @@ func TestGetTransactionHandlerFromPoolShouldWorkWithPeek(t *testing.T) { shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return txFromPool, true }, @@ -2026,7 +2028,7 @@ func TestGetTransactionHandlerFromPoolShouldWorkWithPeekFallbackToSearchFirst(t peekCalled := false shardedDataCacherNotifier := &testscommon.ShardedDataStub{ ShardDataStoreCalled: func(cacheId string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { peekCalled = true return nil, false diff --git a/process/constants.go b/process/constants.go index 997e0a2a458..8468feb302b 100644 --- a/process/constants.go +++ b/process/constants.go @@ -73,9 +73,6 @@ const BlockFinality = 1 // MetaBlockValidity defines the block validity which is when checking a metablock const MetaBlockValidity = 1 -// EpochChangeGracePeriod defines the allowed round numbers till the shard has to change the epoch -const EpochChangeGracePeriod = 1 - // MaxHeaderRequestsAllowed defines the maximum number of missing cross-shard headers (gaps) which could be requested // in one round, when node processes a received block const MaxHeaderRequestsAllowed = 20 diff --git a/process/coordinator/process_test.go b/process/coordinator/process_test.go index 5bdd8e086b1..be4150f01b9 100644 --- a/process/coordinator/process_test.go +++ b/process/coordinator/process_test.go @@ -37,6 +37,7 @@ import ( "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" commonMock "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ -82,7 +83,7 @@ func createShardedDataChacherNotifier( return &testscommon.ShardedDataStub{ RegisterOnAddedCalled: func(i func(key []byte, value interface{})) {}, ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, testHash) { return handler, true @@ -127,7 +128,7 @@ 
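With the fixed process.EpochChangeGracePeriod constant removed, the grace period now comes from an epoch-aware handler (epochChangeGracePeriodHandler.GetGracePeriodForEpoch), so the allowed number of rounds past the finality attesting round can differ per epoch. A plausible sketch of such a handler, assuming a sorted per-epoch configuration; the struct and field names are illustrative.

package sketch

import "fmt"

// gracePeriodByEpoch pairs an activation epoch with the grace period (in
// rounds) that applies from that epoch onward.
type gracePeriodByEpoch struct {
	EnableEpoch         uint32
	GracePeriodInRounds uint32
}

type epochChangeGracePeriodHandler struct {
	// entries sorted ascending by EnableEpoch; the first one must cover epoch 0
	configs []gracePeriodByEpoch
}

// GetGracePeriodForEpoch returns the grace period of the newest entry whose
// activation epoch does not exceed the requested epoch.
func (h *epochChangeGracePeriodHandler) GetGracePeriodForEpoch(epoch uint32) (uint32, error) {
	for i := len(h.configs) - 1; i >= 0; i-- {
		if epoch >= h.configs[i].EnableEpoch {
			return h.configs[i].GracePeriodInRounds, nil
		}
	}

	return 0, fmt.Errorf("no grace period configured for epoch %d", epoch)
}

Callers then compute shouldEnterNewEpochRound as epochFinalityAttestingRound + uint64(gracePeriod), as the hunks above do in place of the removed constant.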
func initDataPool(testHash []byte) *dataRetrieverMock.PoolsHolderStub { UnsignedTransactionsCalled: unsignedTxHandler, RewardTransactionsCalled: rewardTxCalled, MetaBlocksCalled: func() storage.Cacher { - return &testscommon.CacherStub{ + return &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, []byte("tx1_hash")) { return &transaction.Transaction{Nonce: 10}, true @@ -150,7 +151,7 @@ func initDataPool(testHash []byte) *dataRetrieverMock.PoolsHolderStub { } }, MiniBlocksCalled: func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -1164,7 +1165,7 @@ func TestTransactionCoordinator_CreateMbsAndProcessTransactionsFromMeNothingToPr shardedCacheMock := &testscommon.ShardedDataStub{ RegisterOnAddedCalled: func(i func(key []byte, value interface{})) {}, ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -2360,7 +2361,7 @@ func TestTransactionCoordinator_VerifyCreatedBlockTransactionsOk(t *testing.T) { return &testscommon.ShardedDataStub{ RegisterOnAddedCalled: func(i func(key []byte, value interface{})) {}, ShardDataStoreCalled: func(id string) (c storage.Cacher) { - return &testscommon.CacherStub{ + return &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if reflect.DeepEqual(key, scrHash) { return scr, true @@ -4542,7 +4543,7 @@ func TestTransactionCoordinator_requestMissingMiniBlocksAndTransactionsShouldWor t.Parallel() args := createMockTransactionCoordinatorArguments() - args.MiniBlockPool = &testscommon.CacherStub{ + args.MiniBlockPool = &cache.CacherStub{ PeekCalled: func(key []byte) (value interface{}, ok bool) { if bytes.Equal(key, []byte("hash0")) || bytes.Equal(key, []byte("hash1")) || bytes.Equal(key, []byte("hash2")) { if bytes.Equal(key, []byte("hash0")) { diff --git a/process/errors.go b/process/errors.go index 060bbdd5009..d3cc22dadcb 100644 --- a/process/errors.go +++ b/process/errors.go @@ -239,6 +239,9 @@ var ErrNilMiniBlockPool = errors.New("nil mini block pool") // ErrNilMetaBlocksPool signals that a nil meta blocks pool was used var ErrNilMetaBlocksPool = errors.New("nil meta blocks pool") +// ErrNilProofsPool signals that a nil proofs pool was used +var ErrNilProofsPool = errors.New("nil proofs pool") + // ErrNilTxProcessor signals that a nil transactions processor was used var ErrNilTxProcessor = errors.New("nil transactions processor") @@ -465,6 +468,9 @@ var ErrNilEpochNotifier = errors.New("nil EpochNotifier") // ErrNilRoundNotifier signals that the provided EpochNotifier is nil var ErrNilRoundNotifier = errors.New("nil RoundNotifier") +// ErrNilChainParametersHandler signals that the provided chain parameters handler is nil +var ErrNilChainParametersHandler = errors.New("nil chain parameters handler") + // ErrInvalidCacheRefreshIntervalInSec signals that the cacheRefreshIntervalInSec is invalid - zero or less var ErrInvalidCacheRefreshIntervalInSec = errors.New("invalid cacheRefreshIntervalInSec") @@ -696,6 +702,9 @@ var ErrNilWhiteListHandler = errors.New("nil whitelist handler") // ErrNilPreferredPeersHolder signals that preferred peers holder is nil var ErrNilPreferredPeersHolder = errors.New("nil preferred peers holder") +// ErrNilInterceptedDataVerifier 
signals that intercepted data verifier is nil +var ErrNilInterceptedDataVerifier = errors.New("nil intercepted data verifier") + // ErrMiniBlocksInWrongOrder signals the miniblocks are in wrong order var ErrMiniBlocksInWrongOrder = errors.New("miniblocks in wrong order, should have been only from me") @@ -1092,6 +1101,9 @@ var ErrInvalidExpiryTimespan = errors.New("invalid expiry timespan") // ErrNilPeerSignatureHandler signals that a nil peer signature handler was provided var ErrNilPeerSignatureHandler = errors.New("nil peer signature handler") +// ErrNilInterceptedDataVerifierFactory signals that a nil intercepted data verifier factory was provided +var ErrNilInterceptedDataVerifierFactory = errors.New("nil intercepted data verifier factory") + // ErrNilPeerAuthenticationCacher signals that a nil peer authentication cacher was provided var ErrNilPeerAuthenticationCacher = errors.New("nil peer authentication cacher") @@ -1137,6 +1149,9 @@ var ErrNilESDTGlobalSettingsHandler = errors.New("nil esdt global settings handl // ErrNilEnableEpochsHandler signals that a nil enable epochs handler has been provided var ErrNilEnableEpochsHandler = errors.New("nil enable epochs handler") +// ErrNilEpochChangeGracePeriodHandler signals that a nil epoch change grace period handler has been provided +var ErrNilEpochChangeGracePeriodHandler = errors.New("nil epoch change grace period handler") + // ErrNilMultiSignerContainer signals that the given multisigner container is nil var ErrNilMultiSignerContainer = errors.New("nil multiSigner container") @@ -1233,6 +1248,15 @@ var ErrTransferAndExecuteByUserAddressesAreNil = errors.New("transfer and execut // ErrRelayedTxV3Disabled signals that relayed tx v3 are disabled var ErrRelayedTxV3Disabled = errors.New("relayed tx v3 are disabled") +// ErrMissingConfigurationForEpochZero signals that the provided configuration doesn't include anything for epoch 0 +var ErrMissingConfigurationForEpochZero = errors.New("missing configuration for epoch 0") + +// ErrEmptyChainParametersConfiguration signals that an empty chain parameters configuration has been provided +var ErrEmptyChainParametersConfiguration = errors.New("empty chain parameters configuration") + +// ErrNoMatchingConfigForProvidedEpoch signals that there is no matching configuration for the provided epoch +var ErrNoMatchingConfigForProvidedEpoch = errors.New("no matching configuration") + // ErrGuardedRelayerNotAllowed signals that the provided relayer is guarded var ErrGuardedRelayerNotAllowed = errors.New("guarded relayer not allowed") @@ -1244,3 +1268,33 @@ var ErrInvalidRelayedTxV3 = errors.New("invalid relayed transaction") // ErrProtocolSustainabilityAddressInMetachain signals that protocol sustainability address is in metachain which is not allowed var ErrProtocolSustainabilityAddressInMetachain = errors.New("protocol sustainability address in metachain") + +// ErrNilHeaderProof signals that a nil header proof has been provided +var ErrNilHeaderProof = errors.New("nil header proof") + +// ErrNilInterceptedDataCache signals that a nil cacher was provided for intercepted data verifier +var ErrNilInterceptedDataCache = errors.New("nil cache for intercepted data") + +// ErrFlagNotActive signals that a flag is not active +var ErrFlagNotActive = errors.New("flag not active") + +// ErrInvalidInterceptedData signals that an invalid data has been intercepted +var ErrInvalidInterceptedData = errors.New("invalid intercepted data") + +// ErrMissingHeaderProof signals that the proof for the header is missing +var 
ErrMissingHeaderProof = errors.New("missing header proof") + +// ErrInvalidHeaderProof signals that an invalid equivalent proof has been provided +var ErrInvalidHeaderProof = errors.New("invalid equivalent proof") + +// ErrUnexpectedHeaderProof signals that a header proof has been provided unexpectedly +var ErrUnexpectedHeaderProof = errors.New("unexpected header proof") + +// ErrEpochMismatch signals that the epoch do not match +var ErrEpochMismatch = errors.New("epoch mismatch") + +// ErrInvalidRatingsConfig signals that an invalid ratings config has been provided +var ErrInvalidRatingsConfig = errors.New("invalid ratings config") + +// ErrNilKeyRWMutexHandler signals that a nil KeyRWMutexHandler has been provided +var ErrNilKeyRWMutexHandler = errors.New("nil key rw mutex handler") diff --git a/process/factory/interceptorscontainer/args.go b/process/factory/interceptorscontainer/args.go index 294e66290b3..8e98c7c18ab 100644 --- a/process/factory/interceptorscontainer/args.go +++ b/process/factory/interceptorscontainer/args.go @@ -2,6 +2,7 @@ package interceptorscontainer import ( crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/heartbeat" @@ -13,34 +14,35 @@ import ( // CommonInterceptorsContainerFactoryArgs holds the arguments needed for the metachain/shard interceptors factories type CommonInterceptorsContainerFactoryArgs struct { - CoreComponents process.CoreComponentsHolder - CryptoComponents process.CryptoComponentsHolder - Accounts state.AccountsAdapter - ShardCoordinator sharding.Coordinator - NodesCoordinator nodesCoordinator.NodesCoordinator - MainMessenger process.TopicHandler - FullArchiveMessenger process.TopicHandler - Store dataRetriever.StorageService - DataPool dataRetriever.PoolsHolder - MaxTxNonceDeltaAllowed int - TxFeeHandler process.FeeHandler - BlockBlackList process.TimeCacher - HeaderSigVerifier process.InterceptedHeaderSigVerifier - HeaderIntegrityVerifier process.HeaderIntegrityVerifier - ValidityAttester process.ValidityAttester - EpochStartTrigger process.EpochStartTriggerHandler - WhiteListHandler process.WhiteListHandler - WhiteListerVerifiedTxs process.WhiteListHandler - AntifloodHandler process.P2PAntifloodHandler - ArgumentsParser process.ArgumentsParser - PreferredPeersHolder process.PreferredPeersHolderHandler - SizeCheckDelta uint32 - RequestHandler process.RequestHandler - PeerSignatureHandler crypto.PeerSignatureHandler - SignaturesHandler process.SignaturesHandler - HeartbeatExpiryTimespanInSec int64 - MainPeerShardMapper process.PeerShardMapper - FullArchivePeerShardMapper process.PeerShardMapper - HardforkTrigger heartbeat.HardforkTrigger - NodeOperationMode common.NodeOperation + CoreComponents process.CoreComponentsHolder + CryptoComponents process.CryptoComponentsHolder + Accounts state.AccountsAdapter + ShardCoordinator sharding.Coordinator + NodesCoordinator nodesCoordinator.NodesCoordinator + MainMessenger process.TopicHandler + FullArchiveMessenger process.TopicHandler + Store dataRetriever.StorageService + DataPool dataRetriever.PoolsHolder + MaxTxNonceDeltaAllowed int + TxFeeHandler process.FeeHandler + BlockBlackList process.TimeCacher + HeaderSigVerifier process.InterceptedHeaderSigVerifier + HeaderIntegrityVerifier process.HeaderIntegrityVerifier + ValidityAttester process.ValidityAttester + EpochStartTrigger 
process.EpochStartTriggerHandler + WhiteListHandler process.WhiteListHandler + WhiteListerVerifiedTxs process.WhiteListHandler + AntifloodHandler process.P2PAntifloodHandler + ArgumentsParser process.ArgumentsParser + PreferredPeersHolder process.PreferredPeersHolderHandler + SizeCheckDelta uint32 + RequestHandler process.RequestHandler + PeerSignatureHandler crypto.PeerSignatureHandler + SignaturesHandler process.SignaturesHandler + HeartbeatExpiryTimespanInSec int64 + MainPeerShardMapper process.PeerShardMapper + FullArchivePeerShardMapper process.PeerShardMapper + HardforkTrigger heartbeat.HardforkTrigger + NodeOperationMode common.NodeOperation + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } diff --git a/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go b/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go index cfed22b39c9..bdd6ea118e1 100644 --- a/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go +++ b/process/factory/interceptorscontainer/baseInterceptorsContainerFactory.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/heartbeat" @@ -31,29 +32,31 @@ const ( ) type baseInterceptorsContainerFactory struct { - mainContainer process.InterceptorsContainer - fullArchiveContainer process.InterceptorsContainer - shardCoordinator sharding.Coordinator - accounts state.AccountsAdapter - store dataRetriever.StorageService - dataPool dataRetriever.PoolsHolder - mainMessenger process.TopicHandler - fullArchiveMessenger process.TopicHandler - nodesCoordinator nodesCoordinator.NodesCoordinator - blockBlackList process.TimeCacher - argInterceptorFactory *interceptorFactory.ArgInterceptedDataFactory - globalThrottler process.InterceptorThrottler - maxTxNonceDeltaAllowed int - antifloodHandler process.P2PAntifloodHandler - whiteListHandler process.WhiteListHandler - whiteListerVerifiedTxs process.WhiteListHandler - preferredPeersHolder process.PreferredPeersHolderHandler - hasher hashing.Hasher - requestHandler process.RequestHandler - mainPeerShardMapper process.PeerShardMapper - fullArchivePeerShardMapper process.PeerShardMapper - hardforkTrigger heartbeat.HardforkTrigger - nodeOperationMode common.NodeOperation + mainContainer process.InterceptorsContainer + fullArchiveContainer process.InterceptorsContainer + shardCoordinator sharding.Coordinator + accounts state.AccountsAdapter + store dataRetriever.StorageService + dataPool dataRetriever.PoolsHolder + mainMessenger process.TopicHandler + fullArchiveMessenger process.TopicHandler + nodesCoordinator nodesCoordinator.NodesCoordinator + blockBlackList process.TimeCacher + argInterceptorFactory *interceptorFactory.ArgInterceptedDataFactory + globalThrottler process.InterceptorThrottler + maxTxNonceDeltaAllowed int + antifloodHandler process.P2PAntifloodHandler + whiteListHandler process.WhiteListHandler + whiteListerVerifiedTxs process.WhiteListHandler + preferredPeersHolder process.PreferredPeersHolderHandler + hasher hashing.Hasher + requestHandler process.RequestHandler + mainPeerShardMapper process.PeerShardMapper + fullArchivePeerShardMapper process.PeerShardMapper + hardforkTrigger 
heartbeat.HardforkTrigger + nodeOperationMode common.NodeOperation + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory + enableEpochsHandler common.EnableEpochsHandler } func checkBaseParams( @@ -285,18 +288,25 @@ func (bicf *baseInterceptorsContainerFactory) createOneTxInterceptor(topic strin return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + internalMarshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: internalMarshaller, - DataFactory: txFactory, - Processor: txProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: topic, + Marshalizer: internalMarshaller, + Hasher: bicf.argInterceptorFactory.CoreComponents.Hasher(), + DataFactory: txFactory, + Processor: txProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -328,18 +338,25 @@ func (bicf *baseInterceptorsContainerFactory) createOneUnsignedTxInterceptor(top return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + internalMarshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: internalMarshaller, - DataFactory: txFactory, - Processor: txProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: topic, + Marshalizer: internalMarshaller, + Hasher: bicf.argInterceptorFactory.CoreComponents.Hasher(), + DataFactory: txFactory, + Processor: txProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -371,18 +388,25 @@ func (bicf *baseInterceptorsContainerFactory) createOneRewardTxInterceptor(topic return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + internalMarshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: internalMarshaller, - DataFactory: txFactory, - Processor: txProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: topic, + Marshalizer: internalMarshaller, + Hasher: bicf.argInterceptorFactory.CoreComponents.Hasher(), + DataFactory: txFactory, + Processor: txProcessor, + Throttler: bicf.globalThrottler, 
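Every interceptor built in these factories now receives a per-topic InterceptedDataVerifier created via interceptedDataVerifierFactory.Create(topic), and the wiring is identical across the tx, unsigned tx, reward tx, miniblock, header, heartbeat and validator-info paths. A condensed sketch of that wiring with assumed interface shapes; only the fields relevant to the sketch are kept.

package sketch

type interceptedDataVerifier interface {
	Verify(interceptedData interface{}) error
}

type interceptedDataVerifierFactory interface {
	Create(topic string) (interceptedDataVerifier, error)
}

// argSingleDataInterceptor keeps only the fields relevant to the sketch.
type argSingleDataInterceptor struct {
	Topic                   string
	InterceptedDataVerifier interceptedDataVerifier
	// ... data factory, processor, throttler, antiflood handler, etc.
}

// newInterceptorArgs shows the per-topic wiring: create one verifier per
// topic, fail fast on error, then embed it in the argument struct.
func newInterceptorArgs(factory interceptedDataVerifierFactory, topic string) (argSingleDataInterceptor, error) {
	verifier, err := factory.Create(topic)
	if err != nil {
		return argSingleDataInterceptor{}, err
	}

	return argSingleDataInterceptor{
		Topic:                   topic,
		InterceptedDataVerifier: verifier,
	}, nil
}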
+ AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -403,8 +427,10 @@ func (bicf *baseInterceptorsContainerFactory) generateHeaderInterceptors() error } argProcessor := &processor.ArgHdrInterceptorProcessor{ - Headers: bicf.dataPool.Headers(), - BlockBlackList: bicf.blockBlackList, + Headers: bicf.dataPool.Headers(), + BlockBlackList: bicf.blockBlackList, + Proofs: bicf.dataPool.Proofs(), + EnableEpochsHandler: bicf.enableEpochsHandler, } hdrProcessor, err := processor.NewHdrInterceptorProcessor(argProcessor) if err != nil { @@ -414,17 +440,23 @@ func (bicf *baseInterceptorsContainerFactory) generateHeaderInterceptors() error // compose header shard topic, for example: shardBlocks_0_META identifierHdr := factory.ShardBlocksTopic + shardC.CommunicationIdentifier(core.MetachainShardId) + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifierHdr) + if err != nil { + return err + } + // only one intrashard header topic interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: identifierHdr, - DataFactory: hdrFactory, - Processor: hdrProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: identifierHdr, + DataFactory: hdrFactory, + Processor: hdrProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -502,17 +534,24 @@ func (bicf *baseInterceptorsContainerFactory) createOneMiniBlocksInterceptor(top return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: internalMarshaller, - DataFactory: miniblockFactory, - Processor: miniblockProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: topic, + Marshalizer: internalMarshaller, + Hasher: hasher, + DataFactory: miniblockFactory, + Processor: miniblockProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -527,31 +566,42 @@ func (bicf *baseInterceptorsContainerFactory) createOneMiniBlocksInterceptor(top func (bicf *baseInterceptorsContainerFactory) generateMetachainHeaderInterceptors() error { identifierHdr := factory.MetachainBlocksTopic - hdrFactory, err := interceptorFactory.NewInterceptedMetaHeaderDataFactory(bicf.argInterceptorFactory) + argsInterceptedMetaHeaderFactory := interceptorFactory.ArgInterceptedMetaHeaderFactory{ + ArgInterceptedDataFactory: *bicf.argInterceptorFactory, + } + hdrFactory, err := 
interceptorFactory.NewInterceptedMetaHeaderDataFactory(&argsInterceptedMetaHeaderFactory) if err != nil { return err } argProcessor := &processor.ArgHdrInterceptorProcessor{ - Headers: bicf.dataPool.Headers(), - BlockBlackList: bicf.blockBlackList, + Headers: bicf.dataPool.Headers(), + BlockBlackList: bicf.blockBlackList, + Proofs: bicf.dataPool.Proofs(), + EnableEpochsHandler: bicf.enableEpochsHandler, } hdrProcessor, err := processor.NewHdrInterceptorProcessor(argProcessor) if err != nil { return err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifierHdr) + if err != nil { + return err + } + // only one metachain header topic interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: identifierHdr, - DataFactory: hdrFactory, - Processor: hdrProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: identifierHdr, + DataFactory: hdrFactory, + Processor: hdrProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -577,18 +627,25 @@ func (bicf *baseInterceptorsContainerFactory) createOneTrieNodesInterceptor(topi return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + internalMarshaller := bicf.argInterceptorFactory.CoreComponents.InternalMarshalizer() interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: internalMarshaller, - DataFactory: trieNodesFactory, - Processor: trieNodesProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: topic, + Marshalizer: internalMarshaller, + Hasher: bicf.argInterceptorFactory.CoreComponents.Hasher(), + DataFactory: trieNodesFactory, + Processor: trieNodesProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -647,7 +704,7 @@ func (bicf *baseInterceptorsContainerFactory) generateUnsignedTxsInterceptors() return bicf.addInterceptorsToContainers(keys, interceptorsSlice) } -//------- PeerAuthentication interceptor +// ------- PeerAuthentication interceptor func (bicf *baseInterceptorsContainerFactory) generatePeerAuthenticationInterceptor() error { identifierPeerAuthentication := common.PeerAuthenticationTopic @@ -669,17 +726,24 @@ func (bicf *baseInterceptorsContainerFactory) generatePeerAuthenticationIntercep return err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifierPeerAuthentication) + if err != nil { + return err + } + mdInterceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: identifierPeerAuthentication, - Marshalizer: internalMarshaller, - DataFactory: peerAuthenticationFactory, - Processor: 
peerAuthenticationProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - PreferredPeersHolder: bicf.preferredPeersHolder, - CurrentPeerId: bicf.mainMessenger.ID(), + Topic: identifierPeerAuthentication, + Marshalizer: internalMarshaller, + Hasher: bicf.argInterceptorFactory.CoreComponents.Hasher(), + DataFactory: peerAuthenticationFactory, + Processor: peerAuthenticationProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + PreferredPeersHolder: bicf.preferredPeersHolder, + CurrentPeerId: bicf.mainMessenger.ID(), + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -694,7 +758,7 @@ func (bicf *baseInterceptorsContainerFactory) generatePeerAuthenticationIntercep return bicf.mainContainer.Add(identifierPeerAuthentication, mdInterceptor) } -//------- Heartbeat interceptor +// ------- Heartbeat interceptor func (bicf *baseInterceptorsContainerFactory) generateHeartbeatInterceptor() error { shardC := bicf.shardCoordinator @@ -728,16 +792,22 @@ func (bicf *baseInterceptorsContainerFactory) createHeartbeatV2Interceptor( return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifier) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: identifier, - DataFactory: heartbeatFactory, - Processor: heartbeatProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - PreferredPeersHolder: bicf.preferredPeersHolder, - CurrentPeerId: bicf.mainMessenger.ID(), + Topic: identifier, + DataFactory: heartbeatFactory, + Processor: heartbeatProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + PreferredPeersHolder: bicf.preferredPeersHolder, + CurrentPeerId: bicf.mainMessenger.ID(), + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -777,16 +847,22 @@ func (bicf *baseInterceptorsContainerFactory) createPeerShardInterceptor( return nil, err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifier) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: identifier, - DataFactory: interceptedPeerShardFactory, - Processor: psiProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - CurrentPeerId: bicf.mainMessenger.ID(), - PreferredPeersHolder: bicf.preferredPeersHolder, + Topic: identifier, + DataFactory: interceptedPeerShardFactory, + Processor: psiProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -814,17 +890,24 @@ func (bicf *baseInterceptorsContainerFactory) generateValidatorInfoInterceptor() return err } + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(identifier) + if err != nil { + return err + } + mdInterceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: identifier, - Marshalizer: 
internalMarshaller, - DataFactory: interceptedValidatorInfoFactory, - Processor: validatorInfoProcessor, - Throttler: bicf.globalThrottler, - AntifloodHandler: bicf.antifloodHandler, - WhiteListRequest: bicf.whiteListHandler, - PreferredPeersHolder: bicf.preferredPeersHolder, - CurrentPeerId: bicf.mainMessenger.ID(), + Topic: identifier, + Marshalizer: internalMarshaller, + Hasher: bicf.argInterceptorFactory.CoreComponents.Hasher(), + DataFactory: interceptedValidatorInfoFactory, + Processor: validatorInfoProcessor, + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + PreferredPeersHolder: bicf.preferredPeersHolder, + CurrentPeerId: bicf.mainMessenger.ID(), + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -839,6 +922,38 @@ func (bicf *baseInterceptorsContainerFactory) generateValidatorInfoInterceptor() return bicf.addInterceptorsToContainers([]string{identifier}, []process.Interceptor{interceptor}) } +func (bicf *baseInterceptorsContainerFactory) createOneShardEquivalentProofsInterceptor(topic string) (process.Interceptor, error) { + args := interceptorFactory.ArgInterceptedEquivalentProofsFactory{ + ArgInterceptedDataFactory: *bicf.argInterceptorFactory, + ProofsPool: bicf.dataPool.Proofs(), + } + equivalentProofsFactory := interceptorFactory.NewInterceptedEquivalentProofsFactory(args) + + interceptedDataVerifier, err := bicf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + + interceptor, err := interceptors.NewSingleDataInterceptor( + interceptors.ArgSingleDataInterceptor{ + Topic: topic, + DataFactory: equivalentProofsFactory, + Processor: processor.NewEquivalentProofsInterceptorProcessor(), + Throttler: bicf.globalThrottler, + AntifloodHandler: bicf.antifloodHandler, + WhiteListRequest: bicf.whiteListHandler, + CurrentPeerId: bicf.mainMessenger.ID(), + PreferredPeersHolder: bicf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, + }, + ) + if err != nil { + return nil, err + } + + return bicf.createTopicAndAssignHandler(topic, interceptor, true) +} + func (bicf *baseInterceptorsContainerFactory) addInterceptorsToContainers(keys []string, interceptors []process.Interceptor) error { err := bicf.mainContainer.AddMultiple(keys, interceptors) if err != nil { diff --git a/process/factory/interceptorscontainer/metaInterceptorsContainerFactory.go b/process/factory/interceptorscontainer/metaInterceptorsContainerFactory.go index 38d3e460bce..8f6b8fc6b0a 100644 --- a/process/factory/interceptorscontainer/metaInterceptorsContainerFactory.go +++ b/process/factory/interceptorscontainer/metaInterceptorsContainerFactory.go @@ -5,6 +5,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/process/factory/containers" @@ -79,6 +80,9 @@ func NewMetaInterceptorsContainerFactory( if check.IfNil(args.PeerSignatureHandler) { return nil, process.ErrNilPeerSignatureHandler } + if check.IfNil(args.InterceptedDataVerifierFactory) { + return nil, process.ErrNilInterceptedDataVerifierFactory + } if args.HeartbeatExpiryTimespanInSec < 
minTimespanDurationInSec { return nil, process.ErrInvalidExpiryTimespan } @@ -102,28 +106,30 @@ func NewMetaInterceptorsContainerFactory( } base := &baseInterceptorsContainerFactory{ - mainContainer: containers.NewInterceptorsContainer(), - fullArchiveContainer: containers.NewInterceptorsContainer(), - shardCoordinator: args.ShardCoordinator, - mainMessenger: args.MainMessenger, - fullArchiveMessenger: args.FullArchiveMessenger, - store: args.Store, - dataPool: args.DataPool, - nodesCoordinator: args.NodesCoordinator, - blockBlackList: args.BlockBlackList, - argInterceptorFactory: argInterceptorFactory, - maxTxNonceDeltaAllowed: args.MaxTxNonceDeltaAllowed, - accounts: args.Accounts, - antifloodHandler: args.AntifloodHandler, - whiteListHandler: args.WhiteListHandler, - whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, - preferredPeersHolder: args.PreferredPeersHolder, - hasher: args.CoreComponents.Hasher(), - requestHandler: args.RequestHandler, - mainPeerShardMapper: args.MainPeerShardMapper, - fullArchivePeerShardMapper: args.FullArchivePeerShardMapper, - hardforkTrigger: args.HardforkTrigger, - nodeOperationMode: args.NodeOperationMode, + mainContainer: containers.NewInterceptorsContainer(), + fullArchiveContainer: containers.NewInterceptorsContainer(), + shardCoordinator: args.ShardCoordinator, + mainMessenger: args.MainMessenger, + fullArchiveMessenger: args.FullArchiveMessenger, + store: args.Store, + dataPool: args.DataPool, + nodesCoordinator: args.NodesCoordinator, + blockBlackList: args.BlockBlackList, + argInterceptorFactory: argInterceptorFactory, + maxTxNonceDeltaAllowed: args.MaxTxNonceDeltaAllowed, + accounts: args.Accounts, + antifloodHandler: args.AntifloodHandler, + whiteListHandler: args.WhiteListHandler, + whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, + preferredPeersHolder: args.PreferredPeersHolder, + hasher: args.CoreComponents.Hasher(), + requestHandler: args.RequestHandler, + mainPeerShardMapper: args.MainPeerShardMapper, + fullArchivePeerShardMapper: args.FullArchivePeerShardMapper, + hardforkTrigger: args.HardforkTrigger, + nodeOperationMode: args.NodeOperationMode, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, + enableEpochsHandler: args.CoreComponents.EnableEpochsHandler(), } icf := &metaInterceptorsContainerFactory{ @@ -195,6 +201,11 @@ func (micf *metaInterceptorsContainerFactory) Create() (process.InterceptorsCont return nil, nil, err } + err = micf.generateEquivalentProofsInterceptors() + if err != nil { + return nil, nil, err + } + return micf.mainContainer, micf.fullArchiveContainer, nil } @@ -253,24 +264,32 @@ func (micf *metaInterceptorsContainerFactory) createOneShardHeaderInterceptor(to } argProcessor := &processor.ArgHdrInterceptorProcessor{ - Headers: micf.dataPool.Headers(), - BlockBlackList: micf.blockBlackList, + Headers: micf.dataPool.Headers(), + BlockBlackList: micf.blockBlackList, + Proofs: micf.dataPool.Proofs(), + EnableEpochsHandler: micf.enableEpochsHandler, } hdrProcessor, err := processor.NewHdrInterceptorProcessor(argProcessor) if err != nil { return nil, err } + interceptedDataVerifier, err := micf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := processInterceptors.NewSingleDataInterceptor( processInterceptors.ArgSingleDataInterceptor{ - Topic: topic, - DataFactory: hdrFactory, - Processor: hdrProcessor, - Throttler: micf.globalThrottler, - AntifloodHandler: micf.antifloodHandler, - WhiteListRequest: micf.whiteListHandler, - CurrentPeerId: 
micf.mainMessenger.ID(), - PreferredPeersHolder: micf.preferredPeersHolder, + Topic: topic, + DataFactory: hdrFactory, + Processor: hdrProcessor, + Throttler: micf.globalThrottler, + AntifloodHandler: micf.antifloodHandler, + WhiteListRequest: micf.whiteListHandler, + CurrentPeerId: micf.mainMessenger.ID(), + PreferredPeersHolder: micf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -329,6 +348,39 @@ func (micf *metaInterceptorsContainerFactory) generateRewardTxInterceptors() err return micf.addInterceptorsToContainers(keys, interceptorSlice) } +func (micf *metaInterceptorsContainerFactory) generateEquivalentProofsInterceptors() error { + shardC := micf.shardCoordinator + noOfShards := shardC.NumberOfShards() + + keys := make([]string, noOfShards+1) + interceptorSlice := make([]process.Interceptor, noOfShards+1) + + for idx := uint32(0); idx < noOfShards; idx++ { + // equivalent proofs shard topic, to listen for shard proofs, for example: equivalentProofs_0_META + identifierEquivalentProofs := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(idx) + interceptor, err := micf.createOneShardEquivalentProofsInterceptor(identifierEquivalentProofs) + if err != nil { + return err + } + + keys[int(idx)] = identifierEquivalentProofs + interceptorSlice[int(idx)] = interceptor + } + + // equivalent proofs meta all topic, equivalentProofs_ALL + identifierEquivalentProofs := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(core.AllShardId) + + interceptor, err := micf.createOneShardEquivalentProofsInterceptor(identifierEquivalentProofs) + if err != nil { + return err + } + + keys[noOfShards] = identifierEquivalentProofs + interceptorSlice[noOfShards] = interceptor + + return micf.addInterceptorsToContainers(keys, interceptorSlice) +} + // IsInterfaceNil returns true if there is no value under the interface func (micf *metaInterceptorsContainerFactory) IsInterfaceNil() bool { return micf == nil diff --git a/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go b/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go index 87eeb5f91fb..eafb147747a 100644 --- a/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go +++ b/process/factory/interceptorscontainer/metaInterceptorsContainerFactory_test.go @@ -5,6 +5,9 @@ import ( "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" @@ -14,6 +17,8 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ -21,8 +26,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs 
"github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const maxTxNonceDeltaAllowed = 100 @@ -63,7 +66,7 @@ func createMetaDataPools() dataRetriever.PoolsHolder { return &mock.HeadersCacherStub{} }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, TransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() @@ -72,11 +75,14 @@ func createMetaDataPools() dataRetriever.PoolsHolder { return testscommon.NewShardedDataStub() }, TrieNodesCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, RewardTransactionsCalled: func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } return pools @@ -397,6 +403,18 @@ func TestNewMetaInterceptorsContainerFactory_NilPeerSignatureHandler(t *testing. assert.Equal(t, process.ErrNilPeerSignatureHandler, err) } +func TestNewMetaInterceptorsContainerFactory_NilInterceptedDataVerifierFactory(t *testing.T) { + t.Parallel() + + coreComp, cryptoComp := createMockComponentHolders() + args := getArgumentsShard(coreComp, cryptoComp) + args.InterceptedDataVerifierFactory = nil + icf, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(args) + + assert.Nil(t, icf) + assert.Equal(t, process.ErrNilInterceptedDataVerifierFactory, err) +} + func TestNewMetaInterceptorsContainerFactory_InvalidExpiryTimespan(t *testing.T) { t.Parallel() @@ -521,6 +539,8 @@ func TestMetaInterceptorsContainerFactory_CreateTopicsAndRegisterFailure(t *test testCreateMetaTopicShouldFailOnAllMessenger(t, "generatePeerShardInterceptor", common.ConnectionTopic, "") + testCreateMetaTopicShouldFailOnAllMessenger(t, "generateEquivalentProofsInterceptors", common.EquivalentProofsTopic, "") + t.Run("generatePeerAuthenticationInterceptor_main", testCreateMetaTopicShouldFail(common.PeerAuthenticationTopic, "")) } @@ -541,6 +561,7 @@ func testCreateMetaTopicShouldFail(matchStrToErrOnCreate string, matchStrToErrOn } else { args.MainMessenger = createMetaStubTopicHandler(matchStrToErrOnCreate, matchStrToErrOnRegister) } + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, _ := interceptorscontainer.NewMetaInterceptorsContainerFactory(args) mainContainer, fullArchiveConatiner, err := icf.Create() @@ -556,13 +577,15 @@ func TestMetaInterceptorsContainerFactory_CreateShouldWork(t *testing.T) { coreComp, cryptoComp := createMockComponentHolders() args := getArgumentsMeta(coreComp, cryptoComp) + + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, _ := interceptorscontainer.NewMetaInterceptorsContainerFactory(args) mainContainer, fullArchiveContainer, err := icf.Create() + require.Nil(t, err) assert.NotNil(t, mainContainer) assert.NotNil(t, fullArchiveContainer) - assert.Nil(t, err) } func TestMetaInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { @@ -588,6 +611,8 @@ func TestMetaInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { args := getArgumentsMeta(coreComp, cryptoComp) args.ShardCoordinator = shardCoordinator args.NodesCoordinator = nodesCoordinator + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} + icf, err := 
interceptorscontainer.NewMetaInterceptorsContainerFactory(args) require.Nil(t, err) @@ -604,10 +629,11 @@ func TestMetaInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { numInterceptorsHeartbeatForMetachain := 1 numInterceptorsShardValidatorInfoForMetachain := 1 numInterceptorValidatorInfo := 1 + numInterceptorsEquivalentProofs := noOfShards + 1 totalInterceptors := numInterceptorsMetablock + numInterceptorsShardHeadersForMetachain + numInterceptorsTrieNodes + numInterceptorsTransactionsForMetachain + numInterceptorsUnsignedTxsForMetachain + numInterceptorsMiniBlocksForMetachain + numInterceptorsRewardsTxsForMetachain + numInterceptorsPeerAuthForMetachain + numInterceptorsHeartbeatForMetachain + - numInterceptorsShardValidatorInfoForMetachain + numInterceptorValidatorInfo + numInterceptorsShardValidatorInfoForMetachain + numInterceptorValidatorInfo + numInterceptorsEquivalentProofs assert.Nil(t, err) assert.Equal(t, totalInterceptors, mainContainer.Len()) @@ -637,6 +663,7 @@ func TestMetaInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { args.NodeOperationMode = common.FullArchiveMode args.ShardCoordinator = shardCoordinator args.NodesCoordinator = nodesCoordinator + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, err := interceptorscontainer.NewMetaInterceptorsContainerFactory(args) require.Nil(t, err) @@ -654,10 +681,11 @@ func TestMetaInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { numInterceptorsHeartbeatForMetachain := 1 numInterceptorsShardValidatorInfoForMetachain := 1 numInterceptorValidatorInfo := 1 + numInterceptorsEquivalentProofs := noOfShards + 1 totalInterceptors := numInterceptorsMetablock + numInterceptorsShardHeadersForMetachain + numInterceptorsTrieNodes + numInterceptorsTransactionsForMetachain + numInterceptorsUnsignedTxsForMetachain + numInterceptorsMiniBlocksForMetachain + numInterceptorsRewardsTxsForMetachain + numInterceptorsPeerAuthForMetachain + numInterceptorsHeartbeatForMetachain + - numInterceptorsShardValidatorInfoForMetachain + numInterceptorValidatorInfo + numInterceptorsShardValidatorInfoForMetachain + numInterceptorValidatorInfo + numInterceptorsEquivalentProofs assert.Nil(t, err) assert.Equal(t, totalInterceptors, mainContainer.Len()) @@ -678,34 +706,35 @@ func getArgumentsMeta( cryptoComp *mock.CryptoComponentsMock, ) interceptorscontainer.CommonInterceptorsContainerFactoryArgs { return interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: coreComp, - CryptoComponents: cryptoComp, - Accounts: &stateMock.AccountsStub{}, - ShardCoordinator: mock.NewOneShardCoordinatorMock(), - NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), - MainMessenger: &mock.TopicHandlerStub{}, - FullArchiveMessenger: &mock.TopicHandlerStub{}, - Store: createMetaStore(), - DataPool: createMetaDataPools(), - MaxTxNonceDeltaAllowed: maxTxNonceDeltaAllowed, - TxFeeHandler: &economicsmocks.EconomicsHandlerMock{}, - BlockBlackList: &testscommon.TimeCacheStub{}, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, - HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, - ValidityAttester: &mock.ValidityAttesterStub{}, - EpochStartTrigger: &mock.EpochStartTriggerStub{}, - WhiteListHandler: &testscommon.WhiteListHandlerStub{}, - WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, - AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, - ArgumentsParser: &testscommon.ArgumentParserMock{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - 
RequestHandler: &testscommon.RequestHandlerStub{}, - PeerSignatureHandler: &mock.PeerSignatureHandlerStub{}, - SignaturesHandler: &mock.SignaturesHandlerStub{}, - HeartbeatExpiryTimespanInSec: 30, - MainPeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, - FullArchivePeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, - HardforkTrigger: &testscommon.HardforkTriggerStub{}, - NodeOperationMode: common.NormalOperation, + CoreComponents: coreComp, + CryptoComponents: cryptoComp, + Accounts: &stateMock.AccountsStub{}, + ShardCoordinator: mock.NewOneShardCoordinatorMock(), + NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), + MainMessenger: &mock.TopicHandlerStub{}, + FullArchiveMessenger: &mock.TopicHandlerStub{}, + Store: createMetaStore(), + DataPool: createMetaDataPools(), + MaxTxNonceDeltaAllowed: maxTxNonceDeltaAllowed, + TxFeeHandler: &economicsmocks.EconomicsHandlerMock{}, + BlockBlackList: &testscommon.TimeCacheStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + ValidityAttester: &mock.ValidityAttesterStub{}, + EpochStartTrigger: &mock.EpochStartTriggerStub{}, + WhiteListHandler: &testscommon.WhiteListHandlerStub{}, + WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, + AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, + ArgumentsParser: &testscommon.ArgumentParserMock{}, + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + PeerSignatureHandler: &mock.PeerSignatureHandlerStub{}, + SignaturesHandler: &mock.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + MainPeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, + FullArchivePeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, + HardforkTrigger: &testscommon.HardforkTriggerStub{}, + NodeOperationMode: common.NormalOperation, + InterceptedDataVerifierFactory: &mock.InterceptedDataVerifierFactoryMock{}, } } diff --git a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory.go b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory.go index beef288c54c..d144113d30f 100644 --- a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory.go +++ b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory.go @@ -5,6 +5,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" "github.com/multiversx/mx-chain-go/process/factory/containers" @@ -78,6 +81,9 @@ func NewShardInterceptorsContainerFactory( if check.IfNil(args.PeerSignatureHandler) { return nil, process.ErrNilPeerSignatureHandler } + if check.IfNil(args.InterceptedDataVerifierFactory) { + return nil, process.ErrNilInterceptedDataVerifierFactory + } if args.HeartbeatExpiryTimespanInSec < minTimespanDurationInSec { return nil, process.ErrInvalidExpiryTimespan } @@ -101,28 +107,30 @@ func NewShardInterceptorsContainerFactory( } base := &baseInterceptorsContainerFactory{ - mainContainer: containers.NewInterceptorsContainer(), - fullArchiveContainer: containers.NewInterceptorsContainer(), - accounts: args.Accounts, - shardCoordinator: args.ShardCoordinator, - 
mainMessenger: args.MainMessenger, - fullArchiveMessenger: args.FullArchiveMessenger, - store: args.Store, - dataPool: args.DataPool, - nodesCoordinator: args.NodesCoordinator, - argInterceptorFactory: argInterceptorFactory, - blockBlackList: args.BlockBlackList, - maxTxNonceDeltaAllowed: args.MaxTxNonceDeltaAllowed, - antifloodHandler: args.AntifloodHandler, - whiteListHandler: args.WhiteListHandler, - whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, - preferredPeersHolder: args.PreferredPeersHolder, - hasher: args.CoreComponents.Hasher(), - requestHandler: args.RequestHandler, - mainPeerShardMapper: args.MainPeerShardMapper, - fullArchivePeerShardMapper: args.FullArchivePeerShardMapper, - hardforkTrigger: args.HardforkTrigger, - nodeOperationMode: args.NodeOperationMode, + mainContainer: containers.NewInterceptorsContainer(), + fullArchiveContainer: containers.NewInterceptorsContainer(), + accounts: args.Accounts, + shardCoordinator: args.ShardCoordinator, + mainMessenger: args.MainMessenger, + fullArchiveMessenger: args.FullArchiveMessenger, + store: args.Store, + dataPool: args.DataPool, + nodesCoordinator: args.NodesCoordinator, + argInterceptorFactory: argInterceptorFactory, + blockBlackList: args.BlockBlackList, + maxTxNonceDeltaAllowed: args.MaxTxNonceDeltaAllowed, + antifloodHandler: args.AntifloodHandler, + whiteListHandler: args.WhiteListHandler, + whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, + preferredPeersHolder: args.PreferredPeersHolder, + hasher: args.CoreComponents.Hasher(), + requestHandler: args.RequestHandler, + mainPeerShardMapper: args.MainPeerShardMapper, + fullArchivePeerShardMapper: args.FullArchivePeerShardMapper, + hardforkTrigger: args.HardforkTrigger, + nodeOperationMode: args.NodeOperationMode, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, + enableEpochsHandler: args.CoreComponents.EnableEpochsHandler(), } icf := &shardInterceptorsContainerFactory{ @@ -194,6 +202,11 @@ func (sicf *shardInterceptorsContainerFactory) Create() (process.InterceptorsCon return nil, nil, err } + err = sicf.generateEquivalentProofsInterceptor() + if err != nil { + return nil, nil, err + } + return sicf.mainContainer, sicf.fullArchiveContainer, nil } @@ -235,6 +248,28 @@ func (sicf *shardInterceptorsContainerFactory) generateRewardTxInterceptor() err return sicf.addInterceptorsToContainers(keys, interceptorSlice) } +func (sicf *shardInterceptorsContainerFactory) generateEquivalentProofsInterceptor() error { + shardC := sicf.shardCoordinator + + // equivalent proofs shard topic, for example: equivalentProofs_0_META + identifierEquivalentProofsShardMeta := common.EquivalentProofsTopic + shardC.CommunicationIdentifier(core.MetachainShardId) + + interceptorShardMeta, err := sicf.createOneShardEquivalentProofsInterceptor(identifierEquivalentProofsShardMeta) + if err != nil { + return err + } + + // equivalent proofs ALL topic, to listen for meta proofs, example: equivalentProofs_ALL + identifierEquivalentProofsMetaAll := common.EquivalentProofsTopic + core.CommunicationIdentifierBetweenShards(core.MetachainShardId, core.AllShardId) + + interceptorMetaAll, err := sicf.createOneShardEquivalentProofsInterceptor(identifierEquivalentProofsMetaAll) + if err != nil { + return err + } + + return sicf.addInterceptorsToContainers([]string{identifierEquivalentProofsShardMeta, identifierEquivalentProofsMetaAll}, []process.Interceptor{interceptorShardMeta, interceptorMetaAll}) +} + // IsInterfaceNil returns true if there is no value under the interface func (sicf 
*shardInterceptorsContainerFactory) IsInterfaceNil() bool { return sicf == nil diff --git a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go index 438e1a143cc..b72d32ad037 100644 --- a/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go +++ b/process/factory/interceptorscontainer/shardInterceptorsContainerFactory_test.go @@ -6,7 +6,12 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/versioning" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/common/graceperiod" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" @@ -15,6 +20,8 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" @@ -25,7 +32,6 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" ) var providedHardforkPubKey = []byte("provided hardfork pub key") @@ -64,13 +70,13 @@ func createShardDataPools() dataRetriever.PoolsHolder { return &mock.HeadersCacherStub{} } pools.MiniBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.PeerChangesBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.MetaBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.UnsignedTransactionsCalled = func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() @@ -79,14 +85,18 @@ func createShardDataPools() dataRetriever.PoolsHolder { return testscommon.NewShardedDataStub() } pools.TrieNodesCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.TrieNodesChunksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.CurrBlockTxsCalled = func() dataRetriever.TransactionCacher { return &mock.TxForCurrentBlockStub{} } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } + return pools } @@ -353,6 +363,18 @@ func TestNewShardInterceptorsContainerFactory_NilValidityAttesterShouldErr(t *te assert.Equal(t, process.ErrNilValidityAttester, err) } +func TestNewShardInterceptorsContainerFactory_NilInterceptedDataVerifierFactory(t *testing.T) { + t.Parallel() + + 
coreComp, cryptoComp := createMockComponentHolders() + args := getArgumentsShard(coreComp, cryptoComp) + args.InterceptedDataVerifierFactory = nil + icf, err := interceptorscontainer.NewShardInterceptorsContainerFactory(args) + + assert.Nil(t, icf) + assert.Equal(t, process.ErrNilInterceptedDataVerifierFactory, err) +} + func TestNewShardInterceptorsContainerFactory_InvalidChainIDShouldErr(t *testing.T) { t.Parallel() @@ -479,6 +501,8 @@ func TestShardInterceptorsContainerFactory_CreateTopicsAndRegisterFailure(t *tes testCreateShardTopicShouldFailOnAllMessenger(t, "generatePeerShardIntercepto", common.ConnectionTopic, "") + testCreateShardTopicShouldFailOnAllMessenger(t, "generateEquivalentProofsInterceptor", common.EquivalentProofsTopic, "") + t.Run("generatePeerAuthenticationInterceptor_main", testCreateShardTopicShouldFail(common.PeerAuthenticationTopic, "")) } func testCreateShardTopicShouldFailOnAllMessenger(t *testing.T, testNamePrefix string, matchStrToErrOnCreate string, matchStrToErrOnRegister string) { @@ -492,6 +516,7 @@ func testCreateShardTopicShouldFail(matchStrToErrOnCreate string, matchStrToErrO coreComp, cryptoComp := createMockComponentHolders() args := getArgumentsShard(coreComp, cryptoComp) + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} if strings.Contains(t.Name(), "full_archive") { args.NodeOperationMode = common.FullArchiveMode args.FullArchiveMessenger = createShardStubTopicHandler(matchStrToErrOnCreate, matchStrToErrOnRegister) @@ -558,14 +583,15 @@ func TestShardInterceptorsContainerFactory_CreateShouldWork(t *testing.T) { }, } args.WhiteListerVerifiedTxs = &testscommon.WhiteListHandlerStub{} + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, _ := interceptorscontainer.NewShardInterceptorsContainerFactory(args) mainContainer, fullArchiveContainer, err := icf.Create() + require.Nil(t, err) assert.NotNil(t, mainContainer) assert.NotNil(t, fullArchiveContainer) - assert.Nil(t, err) } func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { @@ -593,6 +619,7 @@ func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { args.ShardCoordinator = shardCoordinator args.NodesCoordinator = nodesCoordinator args.PreferredPeersHolder = &p2pmocks.PeersHolderStub{} + args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, _ := interceptorscontainer.NewShardInterceptorsContainerFactory(args) @@ -609,9 +636,11 @@ func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { numInterceptorHeartbeat := 1 numInterceptorsShardValidatorInfo := 1 numInterceptorValidatorInfo := 1 + numInterceptorEquivalentProofs := 2 totalInterceptors := numInterceptorTxs + numInterceptorsUnsignedTxs + numInterceptorsRewardTxs + numInterceptorHeaders + numInterceptorMiniBlocks + numInterceptorMetachainHeaders + numInterceptorTrieNodes + - numInterceptorPeerAuth + numInterceptorHeartbeat + numInterceptorsShardValidatorInfo + numInterceptorValidatorInfo + numInterceptorPeerAuth + numInterceptorHeartbeat + numInterceptorsShardValidatorInfo + numInterceptorValidatorInfo + + numInterceptorEquivalentProofs assert.Nil(t, err) assert.Equal(t, totalInterceptors, mainContainer.Len()) @@ -641,6 +670,7 @@ func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { args.ShardCoordinator = shardCoordinator args.NodesCoordinator = nodesCoordinator args.PreferredPeersHolder = &p2pmocks.PeersHolderStub{} + 
args.InterceptedDataVerifierFactory = &mock.InterceptedDataVerifierFactoryMock{} icf, _ := interceptorscontainer.NewShardInterceptorsContainerFactory(args) @@ -657,9 +687,11 @@ func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { numInterceptorHeartbeat := 1 numInterceptorsShardValidatorInfo := 1 numInterceptorValidatorInfo := 1 + numInterceptorEquivalentProofs := 2 totalInterceptors := numInterceptorTxs + numInterceptorsUnsignedTxs + numInterceptorsRewardTxs + numInterceptorHeaders + numInterceptorMiniBlocks + numInterceptorMetachainHeaders + numInterceptorTrieNodes + - numInterceptorPeerAuth + numInterceptorHeartbeat + numInterceptorsShardValidatorInfo + numInterceptorValidatorInfo + numInterceptorPeerAuth + numInterceptorHeartbeat + numInterceptorsShardValidatorInfo + numInterceptorValidatorInfo + + numInterceptorEquivalentProofs assert.Nil(t, err) assert.Equal(t, totalInterceptors, mainContainer.Len()) @@ -668,6 +700,7 @@ func TestShardInterceptorsContainerFactory_With4ShardsShouldWork(t *testing.T) { } func createMockComponentHolders() (*mock.CoreComponentsMock, *mock.CryptoComponentsMock) { + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) coreComponents := &mock.CoreComponentsMock{ IntMarsh: &mock.MarshalizerMock{}, TxMarsh: &mock.MarshalizerMock{}, @@ -681,10 +714,11 @@ func createMockComponentHolders() (*mock.CoreComponentsMock, *mock.CryptoCompone MinTransactionVersionCalled: func() uint32 { return 1 }, - EpochNotifierField: &epochNotifier.EpochNotifierStub{}, - TxVersionCheckField: versioning.NewTxVersionChecker(1), - HardforkTriggerPubKeyField: providedHardforkPubKey, - EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + EpochNotifierField: &epochNotifier.EpochNotifierStub{}, + TxVersionCheckField: versioning.NewTxVersionChecker(1), + HardforkTriggerPubKeyField: providedHardforkPubKey, + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + EpochChangeGracePeriodHandlerField: gracePeriod, } multiSigner := cryptoMocks.NewMultiSigner() cryptoComponents := &mock.CryptoComponentsMock{ @@ -703,34 +737,35 @@ func getArgumentsShard( cryptoComp *mock.CryptoComponentsMock, ) interceptorscontainer.CommonInterceptorsContainerFactoryArgs { return interceptorscontainer.CommonInterceptorsContainerFactoryArgs{ - CoreComponents: coreComp, - CryptoComponents: cryptoComp, - Accounts: &stateMock.AccountsStub{}, - ShardCoordinator: mock.NewOneShardCoordinatorMock(), - NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), - MainMessenger: &mock.TopicHandlerStub{}, - FullArchiveMessenger: &mock.TopicHandlerStub{}, - Store: createShardStore(), - DataPool: createShardDataPools(), - MaxTxNonceDeltaAllowed: maxTxNonceDeltaAllowed, - TxFeeHandler: &economicsmocks.EconomicsHandlerMock{}, - BlockBlackList: &testscommon.TimeCacheStub{}, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, - HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, - SizeCheckDelta: 0, - ValidityAttester: &mock.ValidityAttesterStub{}, - EpochStartTrigger: &mock.EpochStartTriggerStub{}, - AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, - WhiteListHandler: &testscommon.WhiteListHandlerStub{}, - WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, - ArgumentsParser: &testscommon.ArgumentParserMock{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - RequestHandler: &testscommon.RequestHandlerStub{}, - PeerSignatureHandler: 
&mock.PeerSignatureHandlerStub{}, - SignaturesHandler: &mock.SignaturesHandlerStub{}, - HeartbeatExpiryTimespanInSec: 30, - MainPeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, - FullArchivePeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, - HardforkTrigger: &testscommon.HardforkTriggerStub{}, + CoreComponents: coreComp, + CryptoComponents: cryptoComp, + Accounts: &stateMock.AccountsStub{}, + ShardCoordinator: mock.NewOneShardCoordinatorMock(), + NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), + MainMessenger: &mock.TopicHandlerStub{}, + FullArchiveMessenger: &mock.TopicHandlerStub{}, + Store: createShardStore(), + DataPool: createShardDataPools(), + MaxTxNonceDeltaAllowed: maxTxNonceDeltaAllowed, + TxFeeHandler: &economicsmocks.EconomicsHandlerMock{}, + BlockBlackList: &testscommon.TimeCacheStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + SizeCheckDelta: 0, + ValidityAttester: &mock.ValidityAttesterStub{}, + EpochStartTrigger: &mock.EpochStartTriggerStub{}, + AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, + WhiteListHandler: &testscommon.WhiteListHandlerStub{}, + WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, + ArgumentsParser: &testscommon.ArgumentParserMock{}, + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + PeerSignatureHandler: &mock.PeerSignatureHandlerStub{}, + SignaturesHandler: &mock.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + MainPeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, + FullArchivePeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, + HardforkTrigger: &testscommon.HardforkTriggerStub{}, + InterceptedDataVerifierFactory: &mock.InterceptedDataVerifierFactoryMock{}, } } diff --git a/process/factory/shard/intermediateProcessorsContainerFactory_test.go b/process/factory/shard/intermediateProcessorsContainerFactory_test.go index eda43c07d4e..8b727710c79 100644 --- a/process/factory/shard/intermediateProcessorsContainerFactory_test.go +++ b/process/factory/shard/intermediateProcessorsContainerFactory_test.go @@ -3,6 +3,8 @@ package shard_test import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" @@ -10,13 +12,13 @@ import ( "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" txExecOrderStub "github.com/multiversx/mx-chain-go/testscommon/common" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" ) func createDataPools() dataRetriever.PoolsHolder { @@ -28,13 +30,13 @@ func createDataPools() dataRetriever.PoolsHolder { return &mock.HeadersCacherStub{} } pools.MiniBlocksCalled = func() storage.Cacher { 
- return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.PeerChangesBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.MetaBlocksCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.UnsignedTransactionsCalled = func() dataRetriever.ShardedDataCacherNotifier { return testscommon.NewShardedDataStub() @@ -43,7 +45,7 @@ func createDataPools() dataRetriever.PoolsHolder { return testscommon.NewShardedDataStub() } pools.TrieNodesCalled = func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() } pools.CurrBlockTxsCalled = func() dataRetriever.TransactionCacher { return &mock.TxForCurrentBlockStub{} diff --git a/process/headerCheck/common.go b/process/headerCheck/common.go index 01946580d87..353c112e501 100644 --- a/process/headerCheck/common.go +++ b/process/headerCheck/common.go @@ -3,20 +3,24 @@ package headerCheck import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" ) // ComputeConsensusGroup will compute the consensus group that assembled the provided block -func ComputeConsensusGroup(header data.HeaderHandler, nodesCoordinator nodesCoordinator.NodesCoordinator) (validatorsGroup []nodesCoordinator.Validator, err error) { +func ComputeConsensusGroup(header data.HeaderHandler, nodesCoordinator nodesCoordinator.NodesCoordinator) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { if check.IfNil(header) { - return nil, process.ErrNilHeaderHandler + return nil, nil, process.ErrNilHeaderHandler } if check.IfNil(nodesCoordinator) { - return nil, process.ErrNilNodesCoordinator + return nil, nil, process.ErrNilNodesCoordinator } prevRandSeed := header.GetPrevRandSeed() + if prevRandSeed == nil { + return nil, nil, process.ErrNilPrevRandSeed + } // TODO: change here with an activation flag if start of epoch block needs to be validated by the new epoch nodes epoch := header.GetEpoch() diff --git a/process/headerCheck/common_test.go b/process/headerCheck/common_test.go index 0961b7f2a20..8924327fcbd 100644 --- a/process/headerCheck/common_test.go +++ b/process/headerCheck/common_test.go @@ -5,10 +5,11 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" ) func TestComputeConsensusGroup(t *testing.T) { @@ -16,14 +17,15 @@ func TestComputeConsensusGroup(t *testing.T) { t.Run("nil header should error", func(t *testing.T) { nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() - nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { + nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { assert.Fail(t, 
"should have not called ComputeValidatorsGroupCalled") - return nil, nil + return nil, nil, nil } - vGroup, err := ComputeConsensusGroup(nil, nodesCoordinatorInstance) + leader, vGroup, err := ComputeConsensusGroup(nil, nodesCoordinatorInstance) assert.Equal(t, process.ErrNilHeaderHandler, err) assert.Nil(t, vGroup) + assert.Nil(t, leader) }) t.Run("nil nodes coordinator should error", func(t *testing.T) { header := &block.Header{ @@ -34,9 +36,10 @@ func TestComputeConsensusGroup(t *testing.T) { PrevRandSeed: []byte("prev rand seed"), } - vGroup, err := ComputeConsensusGroup(header, nil) + leader, vGroup, err := ComputeConsensusGroup(header, nil) assert.Equal(t, process.ErrNilNodesCoordinator, err) assert.Nil(t, vGroup) + assert.Nil(t, leader) }) t.Run("should work for a random block", func(t *testing.T) { header := &block.Header{ @@ -52,18 +55,19 @@ func TestComputeConsensusGroup(t *testing.T) { validatorGroup := []nodesCoordinator.Validator{validator1, validator2} nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() - nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { + nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { assert.Equal(t, header.PrevRandSeed, randomness) assert.Equal(t, header.Round, round) assert.Equal(t, header.ShardID, shardId) assert.Equal(t, header.Epoch, epoch) - return validatorGroup, nil + return validator1, validatorGroup, nil } - vGroup, err := ComputeConsensusGroup(header, nodesCoordinatorInstance) + leader, vGroup, err := ComputeConsensusGroup(header, nodesCoordinatorInstance) assert.Nil(t, err) assert.Equal(t, validatorGroup, vGroup) + assert.Equal(t, validator1, leader) }) t.Run("should work for a start of epoch block", func(t *testing.T) { header := &block.Header{ @@ -80,18 +84,19 @@ func TestComputeConsensusGroup(t *testing.T) { validatorGroup := []nodesCoordinator.Validator{validator1, validator2} nodesCoordinatorInstance := shardingMocks.NewNodesCoordinatorMock() - nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { + nodesCoordinatorInstance.ComputeValidatorsGroupCalled = func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { assert.Equal(t, header.PrevRandSeed, randomness) assert.Equal(t, header.Round, round) assert.Equal(t, header.ShardID, shardId) assert.Equal(t, header.Epoch-1, epoch) - return validatorGroup, nil + return validator1, validatorGroup, nil } - vGroup, err := ComputeConsensusGroup(header, nodesCoordinatorInstance) + leader, vGroup, err := ComputeConsensusGroup(header, nodesCoordinatorInstance) assert.Nil(t, err) assert.Equal(t, validatorGroup, vGroup) + assert.Equal(t, validator1, leader) }) } diff --git a/process/headerCheck/errors.go b/process/headerCheck/errors.go index e0d4123ae2b..f565a56f7c3 100644 --- a/process/headerCheck/errors.go +++ b/process/headerCheck/errors.go @@ -2,13 +2,6 @@ package headerCheck import "errors" -// ErrNotEnoughSignatures signals that a block is not signed by at least the minimum number of validators from -// the consensus group -var ErrNotEnoughSignatures = 
errors.New("not enough signatures in block") - -// ErrWrongSizeBitmap signals that the provided bitmap's length is bigger than the one that was required -var ErrWrongSizeBitmap = errors.New("wrong size bitmap has been provided") - // ErrInvalidReferenceChainID signals that the provided reference chain ID is not valid var ErrInvalidReferenceChainID = errors.New("invalid reference Chain ID provided") @@ -23,3 +16,12 @@ var ErrIndexOutOfBounds = errors.New("index is out of bounds") // ErrIndexNotSelected signals that the given index is not selected var ErrIndexNotSelected = errors.New("index is not selected") + +// ErrProofShardMismatch signals that the proof shard does not match the header shard +var ErrProofShardMismatch = errors.New("proof shard mismatch") + +// ErrProofHeaderHashMismatch signals that the proof header hash does not match the header hash +var ErrProofHeaderHashMismatch = errors.New("proof header hash mismatch") + +// ErrProofNotExpected signals that the proof is not expected +var ErrProofNotExpected = errors.New("proof not expected") diff --git a/process/headerCheck/headerSignatureVerify.go b/process/headerCheck/headerSignatureVerify.go index 308af919366..8adf9503074 100644 --- a/process/headerCheck/headerSignatureVerify.go +++ b/process/headerCheck/headerSignatureVerify.go @@ -1,7 +1,8 @@ package headerCheck import ( - "math/bits" + "fmt" + "time" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" @@ -9,12 +10,17 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + + "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - logger "github.com/multiversx/mx-chain-logger-go" ) +const headerWaitDelayAtTransition = 50 * time.Millisecond + var _ process.InterceptedHeaderSigVerifier = (*HeaderSigVerifier)(nil) var log = logger.GetOrCreate("process/headerCheck") @@ -28,6 +34,10 @@ type ArgsHeaderSigVerifier struct { SingleSigVerifier crypto.SingleSigner KeyGen crypto.KeyGenerator FallbackHeaderValidator process.FallbackHeaderValidator + EnableEpochsHandler common.EnableEpochsHandler + HeadersPool dataRetriever.HeadersPool + ProofsPool dataRetriever.ProofsPool + StorageService dataRetriever.StorageService } // HeaderSigVerifier is component used to check if a header is valid @@ -39,6 +49,10 @@ type HeaderSigVerifier struct { singleSigVerifier crypto.SingleSigner keyGen crypto.KeyGenerator fallbackHeaderValidator process.FallbackHeaderValidator + enableEpochsHandler common.EnableEpochsHandler + headersPool dataRetriever.HeadersPool + proofsPool dataRetriever.ProofsPool + storageService dataRetriever.StorageService } // NewHeaderSigVerifier will create a new instance of HeaderSigVerifier @@ -56,6 +70,10 @@ func NewHeaderSigVerifier(arguments *ArgsHeaderSigVerifier) (*HeaderSigVerifier, singleSigVerifier: arguments.SingleSigVerifier, keyGen: arguments.KeyGen, fallbackHeaderValidator: arguments.FallbackHeaderValidator, + enableEpochsHandler: arguments.EnableEpochsHandler, + 
headersPool: arguments.HeadersPool, + proofsPool: arguments.ProofsPool, + storageService: arguments.StorageService, }, nil } @@ -91,6 +109,18 @@ func checkArgsHeaderSigVerifier(arguments *ArgsHeaderSigVerifier) error { if check.IfNil(arguments.FallbackHeaderValidator) { return process.ErrNilFallbackHeaderValidator } + if check.IfNil(arguments.EnableEpochsHandler) { + return process.ErrNilEnableEpochsHandler + } + if check.IfNil(arguments.HeadersPool) { + return process.ErrNilHeadersDataPool + } + if check.IfNil(arguments.ProofsPool) { + return process.ErrNilProofsPool + } + if check.IfNil(arguments.StorageService) { + return process.ErrNilStorageService + } return nil } @@ -109,54 +139,117 @@ func isIndexInBitmap(index uint16, bitmap []byte) error { return nil } -func (hsv *HeaderSigVerifier) getConsensusSigners(header data.HeaderHandler) ([][]byte, error) { - randSeed := header.GetPrevRandSeed() - bitmap := header.GetPubKeysBitmap() - if len(bitmap) == 0 { +func (hsv *HeaderSigVerifier) getConsensusSignersForEquivalentProofs(proof data.HeaderProofHandler) ([][]byte, error) { + if check.IfNil(proof) { + return nil, process.ErrNilHeaderProof + } + if !hsv.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, proof.GetHeaderEpoch()) { + return nil, process.ErrUnexpectedHeaderProof + } + + // TODO: remove if start of epochForConsensus block needs to be validated by the new epochForConsensus nodes + epochForConsensus := common.GetEpochForConsensus(proof) + + consensusPubKeys, err := hsv.nodesCoordinator.GetAllEligibleValidatorsPublicKeysForShard( + epochForConsensus, + proof.GetHeaderShardId(), + ) + if err != nil { + return nil, err + } + + shouldApplyFallbackValidation := hsv.fallbackHeaderValidator.ShouldApplyFallbackValidationForHeaderWith( + proof.GetHeaderShardId(), + proof.GetIsStartOfEpoch(), + proof.GetHeaderRound(), + proof.GetHeaderHash(), + ) + + err = common.IsConsensusBitmapValid( + log, + consensusPubKeys, + proof.GetPubKeysBitmap(), + shouldApplyFallbackValidation, + ) + if err != nil { + return nil, err + } + + return getPubKeySigners(consensusPubKeys, proof.GetPubKeysBitmap()), nil +} + +func (hsv *HeaderSigVerifier) getConsensusSigners( + randSeed []byte, + shardID uint32, + epoch uint32, + startOfEpochBlock bool, + round uint64, + prevHash []byte, + pubKeysBitmap []byte, +) ([][]byte, error) { + if len(pubKeysBitmap) == 0 { return nil, process.ErrNilPubKeysBitmap } - if bitmap[0]&1 == 0 { - return nil, process.ErrBlockProposerSignatureMissing + + if !hsv.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + if pubKeysBitmap[0]&1 == 0 { + return nil, process.ErrBlockProposerSignatureMissing + } } // TODO: remove if start of epochForConsensus block needs to be validated by the new epochForConsensus nodes - epochForConsensus := header.GetEpoch() - if header.IsStartOfEpochBlock() && epochForConsensus > 0 { + epochForConsensus := epoch + if startOfEpochBlock && epochForConsensus > 0 { epochForConsensus = epochForConsensus - 1 } - consensusPubKeys, err := hsv.nodesCoordinator.GetConsensusValidatorsPublicKeys( + _, consensusPubKeys, err := hsv.nodesCoordinator.GetConsensusValidatorsPublicKeys( randSeed, - header.GetRound(), - header.GetShardID(), + round, + shardID, epochForConsensus, ) if err != nil { return nil, err } - err = hsv.verifyConsensusSize(consensusPubKeys, header) + shouldApplyFallbackValidation := hsv.fallbackHeaderValidator.ShouldApplyFallbackValidationForHeaderWith( + shardID, + startOfEpochBlock, + round, + prevHash, + ) + + err = 
common.IsConsensusBitmapValid( + log, + consensusPubKeys, + pubKeysBitmap, + shouldApplyFallbackValidation, + ) if err != nil { return nil, err } + return getPubKeySigners(consensusPubKeys, pubKeysBitmap), nil +} + +func getPubKeySigners(consensusPubKeys []string, pubKeysBitmap []byte) [][]byte { pubKeysSigners := make([][]byte, 0, len(consensusPubKeys)) for i := range consensusPubKeys { - err = isIndexInBitmap(uint16(i), bitmap) + err := isIndexInBitmap(uint16(i), pubKeysBitmap) if err != nil { continue } pubKeysSigners = append(pubKeysSigners, []byte(consensusPubKeys[i])) } - return pubKeysSigners, nil + return pubKeysSigners } // VerifySignature will check if signature is correct func (hsv *HeaderSigVerifier) VerifySignature(header data.HeaderHandler) error { - multiSigVerifier, err := hsv.multiSigContainer.GetMultiSigner(header.GetEpoch()) - if err != nil { - return err + if hsv.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + return nil } headerCopy, err := hsv.copyHeaderWithoutSig(header) @@ -169,52 +262,112 @@ func (hsv *HeaderSigVerifier) VerifySignature(header data.HeaderHandler) error { return err } - pubKeysSigners, err := hsv.getConsensusSigners(header) + bitmap := header.GetPubKeysBitmap() + sig := header.GetSignature() + return hsv.VerifySignatureForHash(headerCopy, hash, bitmap, sig) +} + +// VerifySignatureForHash will check if signature is correct for the provided hash +func (hsv *HeaderSigVerifier) VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error { + multiSigVerifier, err := hsv.multiSigContainer.GetMultiSigner(header.GetEpoch()) if err != nil { return err } - return multiSigVerifier.VerifyAggregatedSig(pubKeysSigners, hash, header.GetSignature()) + randSeed := header.GetPrevRandSeed() + if randSeed == nil { + return process.ErrNilPrevRandSeed + } + pubKeysSigners, err := hsv.getConsensusSigners( + randSeed, + header.GetShardID(), + header.GetEpoch(), + header.IsStartOfEpochBlock(), + header.GetRound(), + header.GetPrevHash(), + pubkeysBitmap, + ) + if err != nil { + return err + } + + return multiSigVerifier.VerifyAggregatedSig(pubKeysSigners, hash, signature) } -func (hsv *HeaderSigVerifier) verifyConsensusSize(consensusPubKeys []string, header data.HeaderHandler) error { - consensusSize := len(consensusPubKeys) - bitmap := header.GetPubKeysBitmap() +func (hsv *HeaderSigVerifier) getHeaderForProofAtTransition(proof data.HeaderProofHandler) (data.HeaderHandler, error) { + var header data.HeaderHandler + var err error + + for { + header, err = process.GetHeader(proof.GetHeaderHash(), hsv.headersPool, hsv.storageService, hsv.marshalizer, proof.GetHeaderShardId()) + if err == nil { + break + } - expectedBitmapSize := consensusSize / 8 - if consensusSize%8 != 0 { - expectedBitmapSize++ + log.Debug("getHeaderForProofAtTransition: failed to get header, will wait and try again", + "headerHash", proof.GetHeaderHash(), + "error", err.Error(), + ) + + time.Sleep(headerWaitDelayAtTransition) + } + + return header, nil +} + +func (hsv *HeaderSigVerifier) verifyHeaderProofAtTransition(proof data.HeaderProofHandler) error { + if check.IfNil(proof) { + return process.ErrNilHeaderProof } - if len(bitmap) != expectedBitmapSize { - log.Debug("wrong size bitmap", - "expected number of bytes", expectedBitmapSize, - "actual", len(bitmap)) - return ErrWrongSizeBitmap + header, err := hsv.getHeaderForProofAtTransition(proof) + if err != nil { + return err } - numOfOnesInBitmap := 0 - for index := 
range bitmap { - numOfOnesInBitmap += bits.OnesCount8(bitmap[index]) + consensusPubKeys, err := hsv.getConsensusSigners( + header.GetPrevRandSeed(), + proof.GetHeaderShardId(), + proof.GetHeaderEpoch(), + proof.GetIsStartOfEpoch(), + proof.GetHeaderRound(), + proof.GetHeaderHash(), + proof.GetPubKeysBitmap()) + if err != nil { + return err } - minNumRequiredSignatures := core.GetPBFTThreshold(consensusSize) - if hsv.fallbackHeaderValidator.ShouldApplyFallbackValidation(header) { - minNumRequiredSignatures = core.GetPBFTFallbackThreshold(consensusSize) - log.Warn("HeaderSigVerifier.verifyConsensusSize: fallback validation has been applied", - "minimum number of signatures required", minNumRequiredSignatures, - "actual number of signatures in bitmap", numOfOnesInBitmap, - ) + multiSigVerifier, err := hsv.multiSigContainer.GetMultiSigner(proof.GetHeaderEpoch()) + if err != nil { + return err } - if numOfOnesInBitmap >= minNumRequiredSignatures { - return nil + return multiSigVerifier.VerifyAggregatedSig(consensusPubKeys, proof.GetHeaderHash(), proof.GetAggregatedSignature()) +} + +// VerifyHeaderProof checks if the proof is correct for the header +func (hsv *HeaderSigVerifier) VerifyHeaderProof(proofHandler data.HeaderProofHandler) error { + if check.IfNil(proofHandler) { + return process.ErrNilHeaderProof + } + if !hsv.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, proofHandler.GetHeaderEpoch()) { + return fmt.Errorf("%w for flag %s", process.ErrFlagNotActive, common.AndromedaFlag) + } + + if common.IsEpochStartProofForFlagActivation(proofHandler, hsv.enableEpochsHandler) { + return hsv.verifyHeaderProofAtTransition(proofHandler) } - log.Debug("not enough signatures", - "minimum expected", minNumRequiredSignatures, - "actual", numOfOnesInBitmap) + multiSigVerifier, err := hsv.multiSigContainer.GetMultiSigner(proofHandler.GetHeaderEpoch()) + if err != nil { + return err + } - return ErrNotEnoughSignatures + consensusPubKeys, err := hsv.getConsensusSignersForEquivalentProofs(proofHandler) + if err != nil { + return err + } + + return multiSigVerifier.VerifyAggregatedSig(consensusPubKeys, proofHandler.GetHeaderHash(), proofHandler.GetAggregatedSignature()) } // VerifyRandSeed will check if rand seed is correct @@ -282,7 +435,15 @@ func (hsv *HeaderSigVerifier) IsInterfaceNil() bool { func (hsv *HeaderSigVerifier) verifyRandSeed(leaderPubKey crypto.PublicKey, header data.HeaderHandler) error { prevRandSeed := header.GetPrevRandSeed() + if prevRandSeed == nil { + return process.ErrNilPrevRandSeed + } + randSeed := header.GetRandSeed() + if randSeed == nil { + return process.ErrNilRandSeed + } + return hsv.singleSigVerifier.Verify(leaderPubKey, prevRandSeed, randSeed) } @@ -301,13 +462,11 @@ func (hsv *HeaderSigVerifier) verifyLeaderSignature(leaderPubKey crypto.PublicKe } func (hsv *HeaderSigVerifier) getLeader(header data.HeaderHandler) (crypto.PublicKey, error) { - headerConsensusGroup, err := ComputeConsensusGroup(header, hsv.nodesCoordinator) + leader, _, err := ComputeConsensusGroup(header, hsv.nodesCoordinator) if err != nil { return nil, err } - - leaderPubKeyValidator := headerConsensusGroup[0] - return hsv.keyGen.PublicKeyFromByteArray(leaderPubKeyValidator.PubKey()) + return hsv.keyGen.PublicKeyFromByteArray(leader.PubKey()) } func (hsv *HeaderSigVerifier) copyHeaderWithoutSig(header data.HeaderHandler) (data.HeaderHandler, error) { @@ -322,9 +481,11 @@ func (hsv *HeaderSigVerifier) copyHeaderWithoutSig(header data.HeaderHandler) (d return nil, err } - err = 
headerCopy.SetLeaderSignature(nil) - if err != nil { - return nil, err + if !hsv.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + err = headerCopy.SetLeaderSignature(nil) + if err != nil { + return nil, err + } } return headerCopy, nil diff --git a/process/headerCheck/headerSignatureVerify_test.go b/process/headerCheck/headerSignatureVerify_test.go index f89b8cf90ca..a3ea9f874f2 100644 --- a/process/headerCheck/headerSignatureVerify_test.go +++ b/process/headerCheck/headerSignatureVerify_test.go @@ -3,32 +3,68 @@ package headerCheck import ( "bytes" "errors" + "strconv" + "strings" "testing" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" dataBlock "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/require" + testscommonStorage "github.com/multiversx/mx-chain-go/testscommon/storage" ) const defaultChancesSelection = 1 +var expectedErr = errors.New("expected error") + func createHeaderSigVerifierArgs() *ArgsHeaderSigVerifier { + v1, _ := nodesCoordinator.NewValidator([]byte("pubKey1"), 1, defaultChancesSelection) + v2, _ := nodesCoordinator.NewValidator([]byte("pubKey2"), 1, defaultChancesSelection) return &ArgsHeaderSigVerifier{ - Marshalizer: &mock.MarshalizerMock{}, - Hasher: &hashingMocks.HasherMock{}, - NodesCoordinator: &shardingMocks.NodesCoordinatorMock{}, - MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(cryptoMocks.NewMultiSigner()), - SingleSigVerifier: &mock.SignerMock{}, - KeyGen: &mock.SingleSignKeyGenMock{}, + Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + NodesCoordinator: &shardingMocks.NodesCoordinatorMock{ + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { + return v1, []nodesCoordinator.Validator{v1, v2}, nil + }, + GetAllEligibleValidatorsPublicKeysForShardCalled: func(epoch uint32, shardID uint32) ([]string, error) { + return []string{"pubKey1", "pubKey2"}, nil + }, + }, + MultiSigContainer: cryptoMocks.NewMultiSignerContainerMock(cryptoMocks.NewMultiSigner()), + SingleSigVerifier: &mock.SignerMock{}, + KeyGen: &mock.SingleSignKeyGenMock{ + PublicKeyFromByteArrayCalled: func(b []byte) (key crypto.PublicKey, err error) { 
+ return &mock.SingleSignPublicKey{}, nil + }, + }, FallbackHeaderValidator: &testscommon.FallBackHeaderValidatorStub{}, + EnableEpochsHandler: enableEpochsHandlerMock.NewEnableEpochsHandlerStub(), + HeadersPool: &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return &dataBlock.Header{ + PrevRandSeed: []byte("prevRandSeed"), + }, nil + }, + }, + ProofsPool: &dataRetrieverMocks.ProofsPoolMock{}, + StorageService: &genericMocks.ChainStorerMock{}, } } @@ -107,6 +143,72 @@ func TestNewHeaderSigVerifier_NilSingleSigShouldErr(t *testing.T) { require.Equal(t, process.ErrNilSingleSigner, err) } +func TestNewHeaderSigVerifier_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.EnableEpochsHandler = nil + hdrSigVerifier, err := NewHeaderSigVerifier(args) + + require.Nil(t, hdrSigVerifier) + require.Equal(t, process.ErrNilEnableEpochsHandler, err) +} + +func TestNewHeaderSigVerifier_NilMultiSigContainerShouldErr(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.MultiSigContainer = nil + hdrSigVerifier, err := NewHeaderSigVerifier(args) + + require.Nil(t, hdrSigVerifier) + require.Equal(t, process.ErrNilMultiSignerContainer, err) +} + +func TestNewHeaderSigVerifier_NilFallbackHeaderValidatorShouldErr(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.FallbackHeaderValidator = nil + hdrSigVerifier, err := NewHeaderSigVerifier(args) + + require.Nil(t, hdrSigVerifier) + require.Equal(t, process.ErrNilFallbackHeaderValidator, err) +} + +func TestNewHeaderSigVerifier_NilHeadersPoolShouldErr(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.HeadersPool = nil + hdrSigVerifier, err := NewHeaderSigVerifier(args) + + require.Nil(t, hdrSigVerifier) + require.Equal(t, process.ErrNilHeadersDataPool, err) +} + +func TestNewHeaderSigVerifier_NilProofsPoolShouldErr(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.ProofsPool = nil + hdrSigVerifier, err := NewHeaderSigVerifier(args) + + require.Nil(t, hdrSigVerifier) + require.Equal(t, process.ErrNilProofsPool, err) +} + +func TestNewHeaderSigVerifier_NilStorageServiceShouldErr(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.StorageService = nil + hdrSigVerifier, err := NewHeaderSigVerifier(args) + + require.Nil(t, hdrSigVerifier) + require.Equal(t, process.ErrNilStorageService, err) +} + func TestNewHeaderSigVerifier_OkValsShouldWork(t *testing.T) { t.Parallel() @@ -123,10 +225,13 @@ func TestHeaderSigVerifier_VerifySignatureNilPrevRandSeedShouldErr(t *testing.T) args := createHeaderSigVerifierArgs() hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + PrevRandSeed: nil, + RandSeed: []byte("rand seed"), + } err := hdrSigVerifier.VerifyRandSeed(header) - require.Equal(t, nodesCoordinator.ErrNilRandomness, err) + require.Equal(t, process.ErrNilPrevRandSeed, err) } func TestHeaderSigVerifier_VerifyRandSeedOk(t *testing.T) { @@ -149,14 +254,17 @@ func TestHeaderSigVerifier_VerifyRandSeedOk(t *testing.T) { pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) 
(leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + PrevRandSeed: []byte("prev rand seed"), + RandSeed: []byte("rand seed"), + } err := hdrSigVerifier.VerifyRandSeed(header) require.Nil(t, err) @@ -184,14 +292,17 @@ func TestHeaderSigVerifier_VerifyRandSeedShouldErrWhenVerificationFails(t *testi pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifyRandSeed(header) require.Equal(t, localError, err) @@ -203,10 +314,13 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureNilRandomnessShouldEr args := createHeaderSigVerifierArgs() hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: nil, + PrevRandSeed: []byte("prev rand seed"), + } err := hdrSigVerifier.VerifyRandSeedAndLeaderSignature(header) - require.Equal(t, nodesCoordinator.ErrNilRandomness, err) + require.Equal(t, process.ErrNilRandSeed, err) } func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureVerifyShouldErrWhenValidationFails(t *testing.T) { @@ -230,14 +344,17 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureVerifyShouldErrWhenVa pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifyRandSeedAndLeaderSignature(header) require.Equal(t, localErr, err) @@ -269,14 +386,16 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureVerifyLeaderSigShould pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + 
ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), LeaderSignature: leaderSig, } @@ -305,29 +424,35 @@ func TestHeaderSigVerifier_VerifyRandSeedAndLeaderSignatureOk(t *testing.T) { pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifyRandSeedAndLeaderSignature(header) require.Nil(t, err) require.Equal(t, 2, count) } -func TestHeaderSigVerifier_VerifyLeaderSignatureNilRandomnessShouldErr(t *testing.T) { +func TestHeaderSigVerifier_VerifyLeaderSignatureNilPrevRandomnessShouldErr(t *testing.T) { t.Parallel() args := createHeaderSigVerifierArgs() hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("rand seed "), + PrevRandSeed: nil, + } err := hdrSigVerifier.VerifyLeaderSignature(header) - require.Equal(t, nodesCoordinator.ErrNilRandomness, err) + require.Equal(t, process.ErrNilPrevRandSeed, err) } func TestHeaderSigVerifier_VerifyLeaderSignatureVerifyShouldErrWhenValidationFails(t *testing.T) { @@ -351,14 +476,17 @@ func TestHeaderSigVerifier_VerifyLeaderSignatureVerifyShouldErrWhenValidationFai pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifyLeaderSignature(header) require.Equal(t, localErr, err) @@ -390,14 +518,16 @@ func TestHeaderSigVerifier_VerifyLeaderSignatureVerifyLeaderSigShouldErr(t *test pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round 
uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), LeaderSignature: leaderSig, } @@ -426,14 +556,17 @@ func TestHeaderSigVerifier_VerifyLeaderSignatureOk(t *testing.T) { pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifyLeaderSignature(header) require.Nil(t, err) @@ -445,7 +578,11 @@ func TestHeaderSigVerifier_VerifySignatureNilBitmapShouldErr(t *testing.T) { args := createHeaderSigVerifierArgs() hdrSigVerifier, _ := NewHeaderSigVerifier(args) - header := &dataBlock.Header{} + header := &dataBlock.Header{ + PubKeysBitmap: nil, + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), + } err := hdrSigVerifier.VerifySignature(header) require.Equal(t, process.ErrNilPubKeysBitmap, err) @@ -458,6 +595,8 @@ func TestHeaderSigVerifier_VerifySignatureBlockProposerSigMissingShouldErr(t *te hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ PubKeysBitmap: []byte("0"), + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) @@ -470,11 +609,12 @@ func TestHeaderSigVerifier_VerifySignatureNilRandomnessShouldErr(t *testing.T) { args := createHeaderSigVerifierArgs() hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ + PrevRandSeed: nil, PubKeysBitmap: []byte("1"), } err := hdrSigVerifier.VerifySignature(header) - require.Equal(t, nodesCoordinator.ErrNilRandomness, err) + require.Equal(t, process.ErrNilPrevRandSeed, err) } func TestHeaderSigVerifier_VerifySignatureWrongSizeBitmapShouldErr(t *testing.T) { @@ -483,9 +623,9 @@ func TestHeaderSigVerifier_VerifySignatureWrongSizeBitmapShouldErr(t *testing.T) args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := 
nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc @@ -493,10 +633,12 @@ func TestHeaderSigVerifier_VerifySignatureWrongSizeBitmapShouldErr(t *testing.T) hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ PubKeysBitmap: []byte("11"), + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) - require.Equal(t, ErrWrongSizeBitmap, err) + require.Equal(t, common.ErrWrongSizeBitmap, err) } func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErr(t *testing.T) { @@ -505,9 +647,9 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErr(t *testing.T) { args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v, v, v, v, v}, nil + return v, []nodesCoordinator.Validator{v, v, v, v, v}, nil }, } args.NodesCoordinator = nc @@ -515,10 +657,12 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErr(t *testing.T) { hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ PubKeysBitmap: []byte("A"), + RandSeed: []byte("randSeed"), + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) - require.Equal(t, ErrNotEnoughSignatures, err) + require.Equal(t, common.ErrNotEnoughSignatures, err) } func TestHeaderSigVerifier_VerifySignatureOk(t *testing.T) { @@ -528,9 +672,9 @@ func TestHeaderSigVerifier_VerifySignatureOk(t *testing.T) { args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v}, nil + return v, []nodesCoordinator.Validator{v}, nil }, } args.NodesCoordinator = nc @@ -544,6 +688,7 @@ func TestHeaderSigVerifier_VerifySignatureOk(t *testing.T) { hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.Header{ PubKeysBitmap: []byte("1"), + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) @@ -558,9 +703,9 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErrWhenFallbackThre args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader 
nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v, v, v, v, v}, nil + return v, []nodesCoordinator.Validator{v, v, v, v, v}, nil }, } fallbackHeaderValidator := &testscommon.FallBackHeaderValidatorStub{ @@ -582,10 +727,11 @@ func TestHeaderSigVerifier_VerifySignatureNotEnoughSigsShouldErrWhenFallbackThre hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.MetaBlock{ PubKeysBitmap: []byte("C"), + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) - require.Equal(t, ErrNotEnoughSignatures, err) + require.Equal(t, common.ErrNotEnoughSignatures, err) require.False(t, wasCalled) } @@ -596,9 +742,9 @@ func TestHeaderSigVerifier_VerifySignatureOkWhenFallbackThresholdCouldBeApplied( args := createHeaderSigVerifierArgs() pkAddr := []byte("aaa00000000000000000000000000000") nc := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validators []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { v, _ := nodesCoordinator.NewValidator(pkAddr, 1, defaultChancesSelection) - return []nodesCoordinator.Validator{v, v, v, v, v}, nil + return v, []nodesCoordinator.Validator{v, v, v, v, v}, nil }, } fallbackHeaderValidator := &testscommon.FallBackHeaderValidatorStub{ @@ -618,10 +764,360 @@ func TestHeaderSigVerifier_VerifySignatureOkWhenFallbackThresholdCouldBeApplied( hdrSigVerifier, _ := NewHeaderSigVerifier(args) header := &dataBlock.MetaBlock{ - PubKeysBitmap: []byte("C"), + PubKeysBitmap: []byte{15}, + PrevRandSeed: []byte("prevRandSeed"), } err := hdrSigVerifier.VerifySignature(header) require.Nil(t, err) require.True(t, wasCalled) } + +func TestHeaderSigVerifier_VerifySignatureWithEquivalentProofsActivated(t *testing.T) { + wasCalled := false + args := createHeaderSigVerifierArgs() + numValidatorsConsensusBeforeActivation := 7 + numValidatorsConsensusAfterActivation := 10 + eligibleListSize := numValidatorsConsensusAfterActivation + eligibleValidatorsKeys := make([]string, eligibleListSize) + eligibleValidators := make([]nodesCoordinator.Validator, eligibleListSize) + activationEpoch := uint32(1) + + for i := 0; i < eligibleListSize; i++ { + eligibleValidatorsKeys[i] = "pubKey" + strconv.Itoa(i) + eligibleValidators[i], _ = nodesCoordinator.NewValidator([]byte(eligibleValidatorsKeys[i]), 1, defaultChancesSelection) + } + + nc := &shardingMocks.NodesCoordinatorMock{ + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validators []nodesCoordinator.Validator, err error) { + if epoch < activationEpoch { + return eligibleValidators[0], eligibleValidators[:numValidatorsConsensusBeforeActivation], nil + } + return eligibleValidators[0], eligibleValidators, nil + }, + GetAllEligibleValidatorsPublicKeysForShardCalled: func(epoch uint32, shardID uint32) ([]string, error) { + return eligibleValidatorsKeys, nil + }, + } + + t.Run("check transition block", func(t *testing.T) { + enableEpochs := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} + args.EnableEpochsHandler = enableEpochs + enableEpochs.IsFlagEnabledInEpochCalled = func(flag core.EnableEpochFlag, epoch uint32) bool 
{ + return epoch >= activationEpoch + } + enableEpochs.GetActivationEpochCalled = func(flag core.EnableEpochFlag) uint32 { + return activationEpoch + } + + args.NodesCoordinator = nc + args.MultiSigContainer = cryptoMocks.NewMultiSignerContainerMock(&cryptoMocks.MultisignerMock{ + VerifyAggregatedSigCalled: func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error { + wasCalled = true + return nil + }}) + hdrSigVerifier, _ := NewHeaderSigVerifier(args) + header := &dataBlock.HeaderV2{ + Header: &dataBlock.Header{ + ShardID: 0, + PrevRandSeed: []byte("prevRandSeed"), + PubKeysBitmap: nil, + Signature: nil, + Epoch: 1, + EpochStartMetaHash: []byte("epoch start meta hash"), // to make this the epoch start block in the shard + + }, + } + + err := hdrSigVerifier.VerifySignature(header) + require.Nil(t, err) + require.False(t, wasCalled) + + // check current block proof + err = hdrSigVerifier.VerifyHeaderProof(&dataBlock.HeaderProof{ + PubKeysBitmap: []byte{0xff}, // bitmap should still have the old format + AggregatedSignature: []byte("aggregated signature"), + HeaderHash: []byte("hash"), + HeaderEpoch: 1, + IsStartOfEpoch: true, + }) + require.Nil(t, err) + }) +} + +func getFilledHeader() data.HeaderHandler { + return &dataBlock.Header{ + PrevHash: []byte("prev hash"), + PrevRandSeed: []byte("prev rand seed"), + RandSeed: []byte("rand seed"), + PubKeysBitmap: []byte{0xFF}, + LeaderSignature: []byte("leader signature"), + } +} + +func TestHeaderSigVerifier_VerifyHeaderProof(t *testing.T) { + t.Parallel() + + t.Run("nil proof should error", func(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.EnableEpochsHandler = enableEpochsHandlerMock.NewEnableEpochsHandlerStub(common.AndromedaFlag) + hdrSigVerifier, err := NewHeaderSigVerifier(args) + require.NoError(t, err) + + err = hdrSigVerifier.VerifyHeaderProof(nil) + require.Equal(t, process.ErrNilHeaderProof, err) + }) + t.Run("flag not active should error", func(t *testing.T) { + t.Parallel() + + hdrSigVerifier, err := NewHeaderSigVerifier(createHeaderSigVerifierArgs()) + require.NoError(t, err) + + err = hdrSigVerifier.VerifyHeaderProof(&dataBlock.HeaderProof{ + PubKeysBitmap: []byte{3}, + }) + require.True(t, errors.Is(err, process.ErrFlagNotActive)) + require.True(t, strings.Contains(err.Error(), string(common.AndromedaFlag))) + }) + t.Run("GetMultiSigner error should error", func(t *testing.T) { + t.Parallel() + + cnt := 0 + args := createHeaderSigVerifierArgs() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + args.MultiSigContainer = &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + cnt++ + if cnt > 1 { + return nil, expectedErr + } + return &cryptoMocks.MultiSignerStub{}, nil + }, + } + hdrSigVerifier, err := NewHeaderSigVerifier(args) + require.NoError(t, err) + + err = hdrSigVerifier.VerifyHeaderProof(&dataBlock.HeaderProof{}) + require.Equal(t, expectedErr, err) + }) + t.Run("getConsensusSignersForEquivalentProofs error should error", func(t *testing.T) { + t.Parallel() + + headerHash := []byte("header hash") + wasVerifyAggregatedSigCalled := false + args := createHeaderSigVerifierArgs() + args.HeadersPool = &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return getFilledHeader(), nil + }, + } + args.EnableEpochsHandler = 
&enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + args.MultiSigContainer = &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return &cryptoMocks.MultiSignerStub{ + VerifyAggregatedSigCalled: func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error { + wasVerifyAggregatedSigCalled = true + return nil + }, + }, nil + }, + } + args.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ + GetAllEligibleValidatorsPublicKeysForShardCalled: func(epoch uint32, shardID uint32) ([]string, error) { + return nil, expectedErr + }, + } + hdrSigVerifier, err := NewHeaderSigVerifier(args) + require.NoError(t, err) + + err = hdrSigVerifier.VerifyHeaderProof(&dataBlock.HeaderProof{ + PubKeysBitmap: []byte{0x3}, + AggregatedSignature: make([]byte, 10), + HeaderHash: headerHash, + }) + require.Equal(t, expectedErr, err) + require.False(t, wasVerifyAggregatedSigCalled) + }) + t.Run("should try multiple times to get header if not available", func(t *testing.T) { + t.Parallel() + + headerHash := []byte("header hash") + wasVerifyAggregatedSigCalled := false + args := createHeaderSigVerifierArgs() + + args.StorageService = &testscommonStorage.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &testscommonStorage.StorerStub{ + SearchFirstCalled: func(key []byte) ([]byte, error) { + return nil, errors.New("not found") + }, + }, nil + }, + } + + numCalls := 0 + args.HeadersPool = &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + if numCalls < 2 { + numCalls++ + return nil, errors.New("not found") + } + + return getFilledHeader(), nil + }, + } + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + args.MultiSigContainer = &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return &cryptoMocks.MultiSignerStub{ + VerifyAggregatedSigCalled: func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error { + wasVerifyAggregatedSigCalled = true + return nil + }, + }, nil + }, + } + hdrSigVerifier, err := NewHeaderSigVerifier(args) + require.NoError(t, err) + + err = hdrSigVerifier.VerifyHeaderProof(&dataBlock.HeaderProof{ + PubKeysBitmap: []byte{0x3}, + AggregatedSignature: make([]byte, 10), + HeaderHash: headerHash, + IsStartOfEpoch: true, + }) + require.NoError(t, err) + require.True(t, wasVerifyAggregatedSigCalled) + + require.Equal(t, 2, numCalls) + }) + + t.Run("should work", func(t *testing.T) { + t.Parallel() + headerHash := []byte("header hash") + wasVerifyAggregatedSigCalled := false + args := createHeaderSigVerifierArgs() + args.HeadersPool = &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + return getFilledHeader(), nil + }, + } + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + args.MultiSigContainer = &cryptoMocks.MultiSignerContainerStub{ + GetMultiSignerCalled: func(epoch uint32) (crypto.MultiSigner, error) { + return &cryptoMocks.MultiSignerStub{ + VerifyAggregatedSigCalled: 
func(pubKeysSigners [][]byte, message []byte, aggSig []byte) error { + wasVerifyAggregatedSigCalled = true + return nil + }, + }, nil + }, + } + hdrSigVerifier, err := NewHeaderSigVerifier(args) + require.NoError(t, err) + + err = hdrSigVerifier.VerifyHeaderProof(&dataBlock.HeaderProof{ + PubKeysBitmap: []byte{0x3}, + AggregatedSignature: make([]byte, 10), + HeaderHash: headerHash, + }) + require.NoError(t, err) + require.True(t, wasVerifyAggregatedSigCalled) + }) +} + +func TestHeaderSigVerifier_getConsensusSignersForEquivalentProofs(t *testing.T) { + t.Parallel() + + t.Run("nil proof should error", func(t *testing.T) { + t.Parallel() + + hdrSigVerifier, _ := NewHeaderSigVerifier(createHeaderSigVerifierArgs()) + require.NotNil(t, hdrSigVerifier) + + signers, err := hdrSigVerifier.getConsensusSignersForEquivalentProofs(nil) + require.Nil(t, signers) + require.Equal(t, process.ErrNilHeaderProof, err) + }) + t.Run("flag not active should error", func(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + hdrSigVerifier, _ := NewHeaderSigVerifier(args) + require.NotNil(t, hdrSigVerifier) + + signers, err := hdrSigVerifier.getConsensusSignersForEquivalentProofs(&dataBlock.HeaderProof{}) + require.Nil(t, signers) + require.Equal(t, process.ErrUnexpectedHeaderProof, err) + }) + t.Run("nodesCoordinator error should error", func(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + args.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ + GetAllEligibleValidatorsPublicKeysForShardCalled: func(epoch uint32, shardID uint32) ([]string, error) { + return nil, expectedErr + }, + } + hdrSigVerifier, _ := NewHeaderSigVerifier(args) + require.NotNil(t, hdrSigVerifier) + + signers, err := hdrSigVerifier.getConsensusSignersForEquivalentProofs(&dataBlock.HeaderProof{ + IsStartOfEpoch: true, // for coverage only + HeaderEpoch: 1, + }) + require.Nil(t, signers) + require.Equal(t, expectedErr, err) + }) + t.Run("invalid consensus bitmap error should error", func(t *testing.T) { + t.Parallel() + + args := createHeaderSigVerifierArgs() + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + args.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ + GetAllEligibleValidatorsPublicKeysForShardCalled: func(epoch uint32, shardID uint32) ([]string, error) { + return []string{"pk1", "pk2", "pk3", "pk4", "pk5", "pk6", "pk7", "pk8"}, nil + }, + } + hdrSigVerifier, _ := NewHeaderSigVerifier(args) + require.NotNil(t, hdrSigVerifier) + + signers, err := hdrSigVerifier.getConsensusSignersForEquivalentProofs(&dataBlock.HeaderProof{ + PubKeysBitmap: []byte{1, 1}, + }) + require.Nil(t, signers) + require.Equal(t, common.ErrWrongSizeBitmap, err) + }) +} diff --git a/process/interceptors/baseDataInterceptor.go b/process/interceptors/baseDataInterceptor.go index 64efb852238..cec00abd756 100644 --- a/process/interceptors/baseDataInterceptor.go +++ b/process/interceptors/baseDataInterceptor.go @@ -6,19 +6,21 @@ import ( "github.com/multiversx/mx-chain-core-go/core" 
"github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" ) type baseDataInterceptor struct { - throttler process.InterceptorThrottler - antifloodHandler process.P2PAntifloodHandler - topic string - currentPeerId core.PeerID - processor process.InterceptorProcessor - mutDebugHandler sync.RWMutex - debugHandler process.InterceptedDebugger - preferredPeersHolder process.PreferredPeersHolderHandler + throttler process.InterceptorThrottler + antifloodHandler process.P2PAntifloodHandler + topic string + currentPeerId core.PeerID + processor process.InterceptorProcessor + mutDebugHandler sync.RWMutex + debugHandler process.InterceptedDebugger + preferredPeersHolder process.PreferredPeersHolderHandler + interceptedDataVerifier process.InterceptedDataVerifier } func (bdi *baseDataInterceptor) preProcessMesage(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { diff --git a/process/interceptors/epochStartMetaBlockInterceptor.go b/process/interceptors/epochStartMetaBlockInterceptor.go index 36bfc121988..0559c1cf7ef 100644 --- a/process/interceptors/epochStartMetaBlockInterceptor.go +++ b/process/interceptors/epochStartMetaBlockInterceptor.go @@ -9,6 +9,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/factory" @@ -56,22 +57,22 @@ func NewEpochStartMetaBlockInterceptor(args ArgsEpochStartMetaBlockInterceptor) } // ProcessReceivedMessage will handle received messages containing epoch start meta blocks -func (e *epochStartMetaBlockInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) error { +func (e *epochStartMetaBlockInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) ([]byte, error) { var epochStartMb block.MetaBlock err := e.marshalizer.Unmarshal(&epochStartMb, message.Data()) if err != nil { - return err + return nil, err } mbHash, err := core.CalculateHash(e.marshalizer, e.hasher, epochStartMb) if err != nil { - return err + return nil, err } if !epochStartMb.IsStartOfEpochBlock() { log.Trace("epochStartMetaBlockInterceptor-ProcessReceivedMessage: received meta block is not of "+ "type epoch start meta block", "hash", mbHash) - return process.ErrNotEpochStartBlock + return nil, process.ErrNotEpochStartBlock } log.Trace("received epoch start meta", "epoch", epochStartMb.GetEpoch(), "from peer", fromConnectedPeer.Pretty()) @@ -82,11 +83,11 @@ func (e *epochStartMetaBlockInterceptor) ProcessReceivedMessage(message p2p.Mess metaBlock, found := e.checkMaps() if !found { - return nil + return mbHash, nil } e.handleFoundEpochStartMetaBlock(metaBlock) - return nil + return mbHash, nil } // this func should be called under mutex protection diff --git a/process/interceptors/epochStartMetaBlockInterceptor_test.go b/process/interceptors/epochStartMetaBlockInterceptor_test.go index 6958be19f8c..e4c916c8a6d 100644 --- a/process/interceptors/epochStartMetaBlockInterceptor_test.go +++ b/process/interceptors/epochStartMetaBlockInterceptor_test.go @@ -7,12 +7,13 @@ import ( 
"github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/require" ) func TestNewEpochStartMetaBlockInterceptor(t *testing.T) { @@ -100,8 +101,9 @@ func TestEpochStartMetaBlockInterceptor_ProcessReceivedMessageUnmarshalError(t * require.NotNil(t, esmbi) message := &p2pmocks.P2PMessageMock{DataField: []byte("wrong meta block bytes")} - err := esmbi.ProcessReceivedMessage(message, "", &p2pmocks.MessengerStub{}) + msgID, err := esmbi.ProcessReceivedMessage(message, "", &p2pmocks.MessengerStub{}) require.Error(t, err) + require.Nil(t, msgID) } func TestEpochStartMetaBlockInterceptor_EntireFlowShouldWorkAndSetTheEpoch(t *testing.T) { @@ -144,23 +146,24 @@ func TestEpochStartMetaBlockInterceptor_EntireFlowShouldWorkAndSetTheEpoch(t *te wrongMetaBlock := &block.MetaBlock{Epoch: 0} wrongMetaBlockBytes, _ := args.Marshalizer.Marshal(wrongMetaBlock) - err := esmbi.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{DataField: metaBlockBytes}, "peer0", &p2pmocks.MessengerStub{}) + msgID, err := esmbi.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{DataField: metaBlockBytes}, "peer0", &p2pmocks.MessengerStub{}) require.NoError(t, err) require.False(t, wasCalled) + require.NotNil(t, msgID) - _ = esmbi.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{DataField: metaBlockBytes}, "peer1", &p2pmocks.MessengerStub{}) + _, _ = esmbi.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{DataField: metaBlockBytes}, "peer1", &p2pmocks.MessengerStub{}) require.False(t, wasCalled) // send again from peer1 => should not be taken into account - _ = esmbi.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{DataField: metaBlockBytes}, "peer1", &p2pmocks.MessengerStub{}) + _, _ = esmbi.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{DataField: metaBlockBytes}, "peer1", &p2pmocks.MessengerStub{}) require.False(t, wasCalled) // send another meta block - _ = esmbi.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{DataField: wrongMetaBlockBytes}, "peer2", &p2pmocks.MessengerStub{}) + _, _ = esmbi.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{DataField: wrongMetaBlockBytes}, "peer2", &p2pmocks.MessengerStub{}) require.False(t, wasCalled) // send the last needed metablock from a new peer => should fetch the epoch - _ = esmbi.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{DataField: metaBlockBytes}, "peer3", &p2pmocks.MessengerStub{}) + _, _ = esmbi.ProcessReceivedMessage(&p2pmocks.P2PMessageMock{DataField: metaBlockBytes}, "peer3", &p2pmocks.MessengerStub{}) require.True(t, wasCalled) } diff --git a/process/interceptors/factory/argInterceptedDataFactory.go b/process/interceptors/factory/argInterceptedDataFactory.go index 37701a92f7a..dbc7350436d 100644 --- a/process/interceptors/factory/argInterceptedDataFactory.go +++ b/process/interceptors/factory/argInterceptedDataFactory.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" crypto 
"github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" @@ -26,6 +27,8 @@ type interceptedDataCoreComponentsHolder interface { IsInterfaceNil() bool HardforkTriggerPubKey() []byte EnableEpochsHandler() common.EnableEpochsHandler + EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler + FieldsSizeChecker() common.FieldsSizeChecker } // interceptedDataCryptoComponentsHolder holds the crypto components required by the intercepted data factory diff --git a/process/interceptors/factory/interceptedDataVerifierFactory.go b/process/interceptors/factory/interceptedDataVerifierFactory.go new file mode 100644 index 00000000000..2775bbdc61a --- /dev/null +++ b/process/interceptors/factory/interceptedDataVerifierFactory.go @@ -0,0 +1,72 @@ +package factory + +import ( + "fmt" + "sync" + "time" + + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/process/interceptors" + "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/storage/cache" +) + +// InterceptedDataVerifierFactoryArgs holds the required arguments for interceptedDataVerifierFactory +type InterceptedDataVerifierFactoryArgs struct { + CacheSpan time.Duration + CacheExpiry time.Duration +} + +// interceptedDataVerifierFactory encapsulates the required arguments to create InterceptedDataVerifier +// Furthermore it will hold all such instances in an internal map. +type interceptedDataVerifierFactory struct { + cacheSpan time.Duration + cacheExpiry time.Duration + + interceptedDataVerifierMap map[string]storage.Cacher + mutex sync.Mutex +} + +// NewInterceptedDataVerifierFactory will create a factory instance that will create instance of InterceptedDataVerifiers +func NewInterceptedDataVerifierFactory(args InterceptedDataVerifierFactoryArgs) *interceptedDataVerifierFactory { + return &interceptedDataVerifierFactory{ + cacheSpan: args.CacheSpan, + cacheExpiry: args.CacheExpiry, + interceptedDataVerifierMap: make(map[string]storage.Cacher), + mutex: sync.Mutex{}, + } +} + +// Create will return an instance of InterceptedDataVerifier +func (idvf *interceptedDataVerifierFactory) Create(topic string) (process.InterceptedDataVerifier, error) { + internalCache, err := cache.NewTimeCacher(cache.ArgTimeCacher{ + DefaultSpan: idvf.cacheSpan, + CacheExpiry: idvf.cacheExpiry, + }) + if err != nil { + return nil, err + } + + idvf.mutex.Lock() + idvf.interceptedDataVerifierMap[topic] = internalCache + idvf.mutex.Unlock() + + return interceptors.NewInterceptedDataVerifier(internalCache) +} + +// Close will close all the sweeping routines created by the cache. 
+func (idvf *interceptedDataVerifierFactory) Close() error { + for topic, cacher := range idvf.interceptedDataVerifierMap { + err := cacher.Close() + if err != nil { + return fmt.Errorf("failed to close cacher on topic %q: %w", topic, err) + } + } + + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (idvf *interceptedDataVerifierFactory) IsInterfaceNil() bool { + return idvf == nil +} diff --git a/process/interceptors/factory/interceptedDataVerifierFactory_test.go b/process/interceptors/factory/interceptedDataVerifierFactory_test.go new file mode 100644 index 00000000000..45f42ec05fd --- /dev/null +++ b/process/interceptors/factory/interceptedDataVerifierFactory_test.go @@ -0,0 +1,44 @@ +package factory + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func createMockArgInterceptedDataVerifierFactory() InterceptedDataVerifierFactoryArgs { + return InterceptedDataVerifierFactoryArgs{ + CacheSpan: time.Second, + CacheExpiry: time.Second, + } +} + +func TestInterceptedDataVerifierFactory_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var factory *interceptedDataVerifierFactory + require.True(t, factory.IsInterfaceNil()) + + factory = NewInterceptedDataVerifierFactory(createMockArgInterceptedDataVerifierFactory()) + require.False(t, factory.IsInterfaceNil()) +} + +func TestNewInterceptedDataVerifierFactory(t *testing.T) { + t.Parallel() + + factory := NewInterceptedDataVerifierFactory(createMockArgInterceptedDataVerifierFactory()) + require.NotNil(t, factory) +} + +func TestInterceptedDataVerifierFactory_Create(t *testing.T) { + t.Parallel() + + factory := NewInterceptedDataVerifierFactory(createMockArgInterceptedDataVerifierFactory()) + require.NotNil(t, factory) + + interceptedDataVerifier, err := factory.Create("mockTopic") + require.NoError(t, err) + + require.False(t, interceptedDataVerifier.IsInterfaceNil()) +} diff --git a/process/interceptors/factory/interceptedEquivalentProofsFactory.go b/process/interceptors/factory/interceptedEquivalentProofsFactory.go new file mode 100644 index 00000000000..17822096283 --- /dev/null +++ b/process/interceptors/factory/interceptedEquivalentProofsFactory.go @@ -0,0 +1,64 @@ +package factory + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/sync" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/dataRetriever" + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" + "github.com/multiversx/mx-chain-go/sharding" +) + +// ArgInterceptedEquivalentProofsFactory is the DTO used to create a new instance of interceptedEquivalentProofsFactory +type ArgInterceptedEquivalentProofsFactory struct { + ArgInterceptedDataFactory + ProofsPool dataRetriever.ProofsPool +} + +type interceptedEquivalentProofsFactory struct { + marshaller marshal.Marshalizer + shardCoordinator sharding.Coordinator + headerSigVerifier consensus.HeaderSigVerifier + proofsPool dataRetriever.ProofsPool + hasher hashing.Hasher + proofSizeChecker common.FieldsSizeChecker + km sync.KeyRWMutexHandler +} + +// 
NewInterceptedEquivalentProofsFactory creates a new instance of interceptedEquivalentProofsFactory +func NewInterceptedEquivalentProofsFactory(args ArgInterceptedEquivalentProofsFactory) *interceptedEquivalentProofsFactory { + return &interceptedEquivalentProofsFactory{ + marshaller: args.CoreComponents.InternalMarshalizer(), + shardCoordinator: args.ShardCoordinator, + headerSigVerifier: args.HeaderSigVerifier, + proofsPool: args.ProofsPool, + hasher: args.CoreComponents.Hasher(), + proofSizeChecker: args.CoreComponents.FieldsSizeChecker(), + km: sync.NewKeyRWMutex(), + } +} + +// Create creates instances of InterceptedData by unmarshalling provided buffer +func (factory *interceptedEquivalentProofsFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { + args := interceptedBlocks.ArgInterceptedEquivalentProof{ + DataBuff: buff, + Marshaller: factory.marshaller, + ShardCoordinator: factory.shardCoordinator, + HeaderSigVerifier: factory.headerSigVerifier, + Proofs: factory.proofsPool, + Hasher: factory.hasher, + ProofSizeChecker: factory.proofSizeChecker, + KeyRWMutexHandler: factory.km, + } + return interceptedBlocks.NewInterceptedEquivalentProof(args) +} + +// IsInterfaceNil returns true if there is no value under the interface +func (factory *interceptedEquivalentProofsFactory) IsInterfaceNil() bool { + return factory == nil +} diff --git a/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go b/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go new file mode 100644 index 00000000000..a972f24bfe8 --- /dev/null +++ b/process/interceptors/factory/interceptedEquivalentProofsFactory_test.go @@ -0,0 +1,86 @@ +package factory + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/consensus/mock" + processMock "github.com/multiversx/mx-chain-go/process/mock" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" +) + +func createMockArgInterceptedEquivalentProofsFactory() ArgInterceptedEquivalentProofsFactory { + return ArgInterceptedEquivalentProofsFactory{ + ArgInterceptedDataFactory: ArgInterceptedDataFactory{ + CoreComponents: &processMock.CoreComponentsMock{ + IntMarsh: &mock.MarshalizerMock{}, + Hash: &hashingMocks.HasherMock{}, + FieldsSizeCheckerField: &testscommon.FieldsSizeCheckerMock{}, + }, + ShardCoordinator: &mock.ShardCoordinatorMock{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + NodesCoordinator: &shardingMocks.NodesCoordinatorStub{}, + }, + ProofsPool: &dataRetriever.ProofsPoolMock{}, + } +} + +func TestInterceptedEquivalentProofsFactory_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var factory *interceptedEquivalentProofsFactory + require.True(t, factory.IsInterfaceNil()) + + factory = NewInterceptedEquivalentProofsFactory(createMockArgInterceptedEquivalentProofsFactory()) + require.False(t, factory.IsInterfaceNil()) +} + +func TestNewInterceptedEquivalentProofsFactory(t *testing.T) { + t.Parallel() + + factory := 
NewInterceptedEquivalentProofsFactory(createMockArgInterceptedEquivalentProofsFactory()) + require.NotNil(t, factory) +} + +func TestInterceptedEquivalentProofsFactory_Create(t *testing.T) { + t.Parallel() + + args := createMockArgInterceptedEquivalentProofsFactory() + factory := NewInterceptedEquivalentProofsFactory(args) + require.NotNil(t, factory) + + providedProof := &block.HeaderProof{ + PubKeysBitmap: []byte("bitmap"), + AggregatedSignature: []byte("sig"), + HeaderHash: []byte("hash"), + HeaderEpoch: 123, + HeaderNonce: 345, + HeaderShardId: 0, + } + providedDataBuff, _ := args.CoreComponents.InternalMarshalizer().Marshal(providedProof) + interceptedData, err := factory.Create(providedDataBuff, "") + require.NoError(t, err) + require.NotNil(t, interceptedData) + + type interceptedEquivalentProof interface { + GetProof() data.HeaderProofHandler + } + interceptedHeaderProof, ok := interceptedData.(interceptedEquivalentProof) + require.True(t, ok) + + proof := interceptedHeaderProof.GetProof() + require.NotNil(t, proof) + require.Equal(t, providedProof.GetPubKeysBitmap(), proof.GetPubKeysBitmap()) + require.Equal(t, providedProof.GetAggregatedSignature(), proof.GetAggregatedSignature()) + require.Equal(t, providedProof.GetHeaderHash(), proof.GetHeaderHash()) + require.Equal(t, providedProof.GetHeaderEpoch(), proof.GetHeaderEpoch()) + require.Equal(t, providedProof.GetHeaderNonce(), proof.GetHeaderNonce()) + require.Equal(t, providedProof.GetHeaderShardId(), proof.GetHeaderShardId()) +} diff --git a/process/interceptors/factory/interceptedHeartbeatDataFactory.go b/process/interceptors/factory/interceptedHeartbeatDataFactory.go index 9956e138a05..5d5c3fed2e6 100644 --- a/process/interceptors/factory/interceptedHeartbeatDataFactory.go +++ b/process/interceptors/factory/interceptedHeartbeatDataFactory.go @@ -29,7 +29,7 @@ func NewInterceptedHeartbeatDataFactory(arg ArgInterceptedDataFactory) (*interce } // Create creates instances of InterceptedData by unmarshalling provided buffer -func (ihdf *interceptedHeartbeatDataFactory) Create(buff []byte) (process.InterceptedData, error) { +func (ihdf *interceptedHeartbeatDataFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { arg := heartbeat.ArgBaseInterceptedHeartbeat{ DataBuff: buff, Marshaller: ihdf.marshalizer, diff --git a/process/interceptors/factory/interceptedHeartbeatDataFactory_test.go b/process/interceptors/factory/interceptedHeartbeatDataFactory_test.go index 5e4af0a0ce5..055830b685d 100644 --- a/process/interceptors/factory/interceptedHeartbeatDataFactory_test.go +++ b/process/interceptors/factory/interceptedHeartbeatDataFactory_test.go @@ -67,7 +67,7 @@ func TestNewInterceptedHeartbeatDataFactory(t *testing.T) { marshaledHeartbeat, err := marshaller.Marshal(hb) assert.Nil(t, err) - interceptedData, err := ihdf.Create(marshaledHeartbeat) + interceptedData, err := ihdf.Create(marshaledHeartbeat, "") assert.NotNil(t, interceptedData) assert.Nil(t, err) assert.True(t, strings.Contains(fmt.Sprintf("%T", interceptedData), "*heartbeat.interceptedHeartbeat")) diff --git a/process/interceptors/factory/interceptedMetaHeaderDataFactory.go b/process/interceptors/factory/interceptedMetaHeaderDataFactory.go index 7567727571d..7068734cd72 100644 --- a/process/interceptors/factory/interceptedMetaHeaderDataFactory.go +++ b/process/interceptors/factory/interceptedMetaHeaderDataFactory.go @@ -1,9 +1,12 @@ package factory import ( + "github.com/multiversx/mx-chain-core-go/core" 
"github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/sharding" @@ -11,18 +14,25 @@ import ( var _ process.InterceptedDataFactory = (*interceptedMetaHeaderDataFactory)(nil) +// ArgInterceptedMetaHeaderFactory is the DTO used to create a new instance of meta header factory +type ArgInterceptedMetaHeaderFactory struct { + ArgInterceptedDataFactory +} + type interceptedMetaHeaderDataFactory struct { - marshalizer marshal.Marshalizer - hasher hashing.Hasher - shardCoordinator sharding.Coordinator - headerSigVerifier process.InterceptedHeaderSigVerifier - headerIntegrityVerifier process.HeaderIntegrityVerifier - validityAttester process.ValidityAttester - epochStartTrigger process.EpochStartTriggerHandler + marshalizer marshal.Marshalizer + hasher hashing.Hasher + shardCoordinator sharding.Coordinator + headerSigVerifier process.InterceptedHeaderSigVerifier + headerIntegrityVerifier process.HeaderIntegrityVerifier + validityAttester process.ValidityAttester + epochStartTrigger process.EpochStartTriggerHandler + enableEpochsHandler common.EnableEpochsHandler + epochChangeGracePeriodHandler common.EpochChangeGracePeriodHandler } // NewInterceptedMetaHeaderDataFactory creates an instance of interceptedMetaHeaderDataFactory -func NewInterceptedMetaHeaderDataFactory(argument *ArgInterceptedDataFactory) (*interceptedMetaHeaderDataFactory, error) { +func NewInterceptedMetaHeaderDataFactory(argument *ArgInterceptedMetaHeaderFactory) (*interceptedMetaHeaderDataFactory, error) { if argument == nil { return nil, process.ErrNilArgumentStruct } @@ -58,27 +68,31 @@ func NewInterceptedMetaHeaderDataFactory(argument *ArgInterceptedDataFactory) (* } return &interceptedMetaHeaderDataFactory{ - marshalizer: argument.CoreComponents.InternalMarshalizer(), - hasher: argument.CoreComponents.Hasher(), - shardCoordinator: argument.ShardCoordinator, - headerSigVerifier: argument.HeaderSigVerifier, - headerIntegrityVerifier: argument.HeaderIntegrityVerifier, - validityAttester: argument.ValidityAttester, - epochStartTrigger: argument.EpochStartTrigger, + marshalizer: argument.CoreComponents.InternalMarshalizer(), + hasher: argument.CoreComponents.Hasher(), + shardCoordinator: argument.ShardCoordinator, + headerSigVerifier: argument.HeaderSigVerifier, + headerIntegrityVerifier: argument.HeaderIntegrityVerifier, + validityAttester: argument.ValidityAttester, + epochStartTrigger: argument.EpochStartTrigger, + enableEpochsHandler: argument.CoreComponents.EnableEpochsHandler(), + epochChangeGracePeriodHandler: argument.CoreComponents.EpochChangeGracePeriodHandler(), }, nil } // Create creates instances of InterceptedData by unmarshalling provided buffer -func (imhdf *interceptedMetaHeaderDataFactory) Create(buff []byte) (process.InterceptedData, error) { +func (imhdf *interceptedMetaHeaderDataFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { arg := &interceptedBlocks.ArgInterceptedBlockHeader{ - HdrBuff: buff, - Marshalizer: imhdf.marshalizer, - Hasher: imhdf.hasher, - ShardCoordinator: imhdf.shardCoordinator, - HeaderSigVerifier: imhdf.headerSigVerifier, - HeaderIntegrityVerifier: 
imhdf.headerIntegrityVerifier, - ValidityAttester: imhdf.validityAttester, - EpochStartTrigger: imhdf.epochStartTrigger, + HdrBuff: buff, + Marshalizer: imhdf.marshalizer, + Hasher: imhdf.hasher, + ShardCoordinator: imhdf.shardCoordinator, + HeaderSigVerifier: imhdf.headerSigVerifier, + HeaderIntegrityVerifier: imhdf.headerIntegrityVerifier, + ValidityAttester: imhdf.validityAttester, + EpochStartTrigger: imhdf.epochStartTrigger, + EnableEpochsHandler: imhdf.enableEpochsHandler, + EpochChangeGracePeriodHandler: imhdf.epochChangeGracePeriodHandler, } return interceptedBlocks.NewInterceptedMetaHeader(arg) diff --git a/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go b/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go index 1930cb9a140..f962fb9806e 100644 --- a/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go +++ b/process/interceptors/factory/interceptedMetaHeaderDataFactory_test.go @@ -10,18 +10,22 @@ import ( "github.com/multiversx/mx-chain-core-go/core/versioning" "github.com/multiversx/mx-chain-core-go/data/block" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/common/graceperiod" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/process/mock" processMocks "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/consensus" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" ) var errSingleSignKeyGenMock = errors.New("errSingleSignKeyGenMock") @@ -60,6 +64,7 @@ func createMockFeeHandler() process.FeeHandler { } func createMockComponentHolders() (*mock.CoreComponentsMock, *mock.CryptoComponentsMock) { + gracePeriod, _ := graceperiod.NewEpochChangeGracePeriod([]config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}) coreComponents := &mock.CoreComponentsMock{ IntMarsh: &mock.MarshalizerMock{}, TxMarsh: &mock.MarshalizerMock{}, @@ -70,10 +75,12 @@ func createMockComponentHolders() (*mock.CoreComponentsMock, *mock.CryptoCompone ChainIdCalled: func() string { return "chainID" }, - TxVersionCheckField: versioning.NewTxVersionChecker(1), - EpochNotifierField: &epochNotifier.EpochNotifierStub{}, - HardforkTriggerPubKeyField: []byte("provided hardfork pub key"), - EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + TxVersionCheckField: versioning.NewTxVersionChecker(1), + EpochNotifierField: &epochNotifier.EpochNotifierStub{}, + HardforkTriggerPubKeyField: []byte("provided hardfork pub key"), + EnableEpochsHandlerField: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + FieldsSizeCheckerField: 
&testscommon.FieldsSizeCheckerMock{}, + EpochChangeGracePeriodHandlerField: gracePeriod, } cryptoComponents := &mock.CryptoComponentsMock{ BlockSig: createMockSigner(), @@ -86,6 +93,31 @@ func createMockComponentHolders() (*mock.CoreComponentsMock, *mock.CryptoCompone return coreComponents, cryptoComponents } +func createMockArgMetaHeaderFactoryArgument( + coreComponents *mock.CoreComponentsMock, + cryptoComponents *mock.CryptoComponentsMock, +) *ArgInterceptedMetaHeaderFactory { + return &ArgInterceptedMetaHeaderFactory{ + ArgInterceptedDataFactory: ArgInterceptedDataFactory{ + CoreComponents: coreComponents, + CryptoComponents: cryptoComponents, + ShardCoordinator: mock.NewOneShardCoordinatorMock(), + NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), + FeeHandler: createMockFeeHandler(), + WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, + ValidityAttester: &mock.ValidityAttesterStub{}, + HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, + EpochStartTrigger: &mock.EpochStartTriggerStub{}, + ArgsParser: &testscommon.ArgumentParserMock{}, + PeerSignatureHandler: &processMocks.PeerSignatureHandlerStub{}, + SignaturesHandler: &processMocks.SignaturesHandlerStub{}, + HeartbeatExpiryTimespanInSec: 30, + PeerID: "pid", + }, + } +} + func createMockArgument( coreComponents *mock.CoreComponentsMock, cryptoComponents *mock.CryptoComponentsMock, @@ -97,7 +129,7 @@ func createMockArgument( NodesCoordinator: shardingMocks.NewNodesCoordinatorMock(), FeeHandler: createMockFeeHandler(), WhiteListerVerifiedTxs: &testscommon.WhiteListHandlerStub{}, - HeaderSigVerifier: &mock.HeaderSigVerifierStub{}, + HeaderSigVerifier: &consensus.HeaderSigVerifierMock{}, ValidityAttester: &mock.ValidityAttesterStub{}, HeaderIntegrityVerifier: &mock.HeaderIntegrityVerifierStub{}, EpochStartTrigger: &mock.EpochStartTriggerStub{}, @@ -123,7 +155,7 @@ func TestNewInterceptedMetaHeaderDataFactory_NilMarshalizerShouldErr(t *testing. 
coreComp, cryptoComp := createMockComponentHolders() coreComp.IntMarsh = nil - arg := createMockArgument(coreComp, cryptoComp) + arg := createMockArgMetaHeaderFactoryArgument(coreComp, cryptoComp) imh, err := NewInterceptedMetaHeaderDataFactory(arg) assert.Nil(t, imh) @@ -135,7 +167,7 @@ func TestNewInterceptedMetaHeaderDataFactory_NilSignMarshalizerShouldErr(t *test coreComp, cryptoComp := createMockComponentHolders() coreComp.TxMarsh = nil - arg := createMockArgument(coreComp, cryptoComp) + arg := createMockArgMetaHeaderFactoryArgument(coreComp, cryptoComp) imh, err := NewInterceptedMetaHeaderDataFactory(arg) assert.True(t, check.IfNil(imh)) @@ -147,7 +179,7 @@ func TestNewInterceptedMetaHeaderDataFactory_NilHasherShouldErr(t *testing.T) { coreComp, cryptoComp := createMockComponentHolders() coreComp.Hash = nil - arg := createMockArgument(coreComp, cryptoComp) + arg := createMockArgMetaHeaderFactoryArgument(coreComp, cryptoComp) imh, err := NewInterceptedMetaHeaderDataFactory(arg) assert.True(t, check.IfNil(imh)) @@ -157,7 +189,7 @@ func TestNewInterceptedMetaHeaderDataFactory_NilHasherShouldErr(t *testing.T) { func TestNewInterceptedMetaHeaderDataFactory_NilHeaderSigVerifierShouldErr(t *testing.T) { t.Parallel() coreComp, cryptoComp := createMockComponentHolders() - arg := createMockArgument(coreComp, cryptoComp) + arg := createMockArgMetaHeaderFactoryArgument(coreComp, cryptoComp) arg.HeaderSigVerifier = nil imh, err := NewInterceptedMetaHeaderDataFactory(arg) @@ -169,7 +201,7 @@ func TestNewInterceptedMetaHeaderDataFactory_NilHeaderIntegrityVerifierShouldErr t.Parallel() coreComp, cryptoComp := createMockComponentHolders() - arg := createMockArgument(coreComp, cryptoComp) + arg := createMockArgMetaHeaderFactoryArgument(coreComp, cryptoComp) arg.HeaderIntegrityVerifier = nil imh, err := NewInterceptedMetaHeaderDataFactory(arg) @@ -181,7 +213,7 @@ func TestNewInterceptedMetaHeaderDataFactory_NilShardCoordinatorShouldErr(t *tes t.Parallel() coreComp, cryptoComp := createMockComponentHolders() - arg := createMockArgument(coreComp, cryptoComp) + arg := createMockArgMetaHeaderFactoryArgument(coreComp, cryptoComp) arg.ShardCoordinator = nil imh, err := NewInterceptedMetaHeaderDataFactory(arg) @@ -196,7 +228,7 @@ func TestNewInterceptedMetaHeaderDataFactory_NilChainIdShouldErr(t *testing.T) { coreComp.ChainIdCalled = func() string { return "" } - arg := createMockArgument(coreComp, cryptoComp) + arg := createMockArgMetaHeaderFactoryArgument(coreComp, cryptoComp) imh, err := NewInterceptedMetaHeaderDataFactory(arg) assert.True(t, check.IfNil(imh)) @@ -207,7 +239,7 @@ func TestNewInterceptedMetaHeaderDataFactory_NilValidityAttesterShouldErr(t *tes t.Parallel() coreComp, cryptoComp := createMockComponentHolders() - arg := createMockArgument(coreComp, cryptoComp) + arg := createMockArgMetaHeaderFactoryArgument(coreComp, cryptoComp) arg.ValidityAttester = nil imh, err := NewInterceptedMetaHeaderDataFactory(arg) @@ -219,7 +251,7 @@ func TestNewInterceptedMetaHeaderDataFactory_ShouldWorkAndCreate(t *testing.T) { t.Parallel() coreComp, cryptoComp := createMockComponentHolders() - arg := createMockArgument(coreComp, cryptoComp) + arg := createMockArgMetaHeaderFactoryArgument(coreComp, cryptoComp) imh, err := NewInterceptedMetaHeaderDataFactory(arg) assert.False(t, check.IfNil(imh)) @@ -229,7 +261,7 @@ func TestNewInterceptedMetaHeaderDataFactory_ShouldWorkAndCreate(t *testing.T) { marshalizer := &mock.MarshalizerMock{} emptyMetaHeader := &block.Header{} emptyMetaHeaderBuff, _ := 
marshalizer.Marshal(emptyMetaHeader) - interceptedData, err := imh.Create(emptyMetaHeaderBuff) + interceptedData, err := imh.Create(emptyMetaHeaderBuff, "") assert.Nil(t, err) _, ok := interceptedData.(*interceptedBlocks.InterceptedMetaHeader) diff --git a/process/interceptors/factory/interceptedMiniblockDataFactory.go b/process/interceptors/factory/interceptedMiniblockDataFactory.go index 64315159197..c51c78dac16 100644 --- a/process/interceptors/factory/interceptedMiniblockDataFactory.go +++ b/process/interceptors/factory/interceptedMiniblockDataFactory.go @@ -1,6 +1,7 @@ package factory import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" @@ -43,7 +44,7 @@ func NewInterceptedMiniblockDataFactory(argument *ArgInterceptedDataFactory) (*i } // Create creates instances of InterceptedData by unmarshalling provided buffer -func (imfd *interceptedMiniblockDataFactory) Create(buff []byte) (process.InterceptedData, error) { +func (imfd *interceptedMiniblockDataFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { arg := &interceptedBlocks.ArgInterceptedMiniblock{ MiniblockBuff: buff, Marshalizer: imfd.marshalizer, diff --git a/process/interceptors/factory/interceptedMiniblockDataFactory_test.go b/process/interceptors/factory/interceptedMiniblockDataFactory_test.go index 45221895a46..3a15d006751 100644 --- a/process/interceptors/factory/interceptedMiniblockDataFactory_test.go +++ b/process/interceptors/factory/interceptedMiniblockDataFactory_test.go @@ -69,7 +69,7 @@ func TestInterceptedMiniblockDataFactory_ShouldWorkAndCreate(t *testing.T) { marshalizer := &mock.MarshalizerMock{} emptyBlockBody := &block.Body{} emptyBlockBodyBuff, _ := marshalizer.Marshal(emptyBlockBody) - interceptedData, err := imdf.Create(emptyBlockBodyBuff) + interceptedData, err := imdf.Create(emptyBlockBodyBuff, "") assert.Nil(t, err) _, ok := interceptedData.(*interceptedBlocks.InterceptedMiniblock) diff --git a/process/interceptors/factory/interceptedPeerAuthenticationDataFactory.go b/process/interceptors/factory/interceptedPeerAuthenticationDataFactory.go index bbf16036170..18b4a4f40a2 100644 --- a/process/interceptors/factory/interceptedPeerAuthenticationDataFactory.go +++ b/process/interceptors/factory/interceptedPeerAuthenticationDataFactory.go @@ -3,6 +3,7 @@ package factory import ( "fmt" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" @@ -71,7 +72,7 @@ func checkArgInterceptedDataFactory(args ArgInterceptedDataFactory) error { } // Create creates instances of InterceptedData by unmarshalling provided buffer -func (ipadf *interceptedPeerAuthenticationDataFactory) Create(buff []byte) (process.InterceptedData, error) { +func (ipadf *interceptedPeerAuthenticationDataFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { arg := heartbeat.ArgInterceptedPeerAuthentication{ ArgBaseInterceptedHeartbeat: heartbeat.ArgBaseInterceptedHeartbeat{ DataBuff: buff, diff --git a/process/interceptors/factory/interceptedPeerAuthenticationDataFactory_test.go b/process/interceptors/factory/interceptedPeerAuthenticationDataFactory_test.go index 
bcb9490613c..d1de48a25ed 100644 --- a/process/interceptors/factory/interceptedPeerAuthenticationDataFactory_test.go +++ b/process/interceptors/factory/interceptedPeerAuthenticationDataFactory_test.go @@ -121,7 +121,7 @@ func TestNewInterceptedPeerAuthenticationDataFactory(t *testing.T) { marshaledPeerAuthentication, err := marshaller.Marshal(peerAuthentication) assert.Nil(t, err) - interceptedData, err := ipadf.Create(marshaledPeerAuthentication) + interceptedData, err := ipadf.Create(marshaledPeerAuthentication, "") assert.NotNil(t, interceptedData) assert.Nil(t, err) assert.True(t, strings.Contains(fmt.Sprintf("%T", interceptedData), "*heartbeat.interceptedPeerAuthentication")) diff --git a/process/interceptors/factory/interceptedPeerShardFactory.go b/process/interceptors/factory/interceptedPeerShardFactory.go index 6d9dd91075f..3234bb89681 100644 --- a/process/interceptors/factory/interceptedPeerShardFactory.go +++ b/process/interceptors/factory/interceptedPeerShardFactory.go @@ -1,6 +1,7 @@ package factory import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/process" @@ -41,7 +42,7 @@ func checkInterceptedDirectConnectionInfoFactoryArgs(args ArgInterceptedDataFact } // Create creates instances of InterceptedData by unmarshalling provided buffer -func (ipsf *interceptedPeerShardFactory) Create(buff []byte) (process.InterceptedData, error) { +func (ipsf *interceptedPeerShardFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { args := p2p.ArgInterceptedPeerShard{ Marshaller: ipsf.marshaller, DataBuff: buff, diff --git a/process/interceptors/factory/interceptedPeerShardFactory_test.go b/process/interceptors/factory/interceptedPeerShardFactory_test.go index 497326be5eb..797a5109113 100644 --- a/process/interceptors/factory/interceptedPeerShardFactory_test.go +++ b/process/interceptors/factory/interceptedPeerShardFactory_test.go @@ -60,7 +60,7 @@ func TestNewInterceptedPeerShardFactory(t *testing.T) { ShardId: "5", } msgBuff, _ := arg.CoreComponents.InternalMarshalizer().Marshal(msg) - interceptedData, err := idcif.Create(msgBuff) + interceptedData, err := idcif.Create(msgBuff, "") assert.Nil(t, err) assert.False(t, check.IfNil(interceptedData)) assert.True(t, strings.Contains(fmt.Sprintf("%T", interceptedData), "*p2p.interceptedPeerShard")) diff --git a/process/interceptors/factory/interceptedRewardTxDataFactory.go b/process/interceptors/factory/interceptedRewardTxDataFactory.go index ba3285a3c05..1ceec65e05f 100644 --- a/process/interceptors/factory/interceptedRewardTxDataFactory.go +++ b/process/interceptors/factory/interceptedRewardTxDataFactory.go @@ -52,7 +52,7 @@ func NewInterceptedRewardTxDataFactory(argument *ArgInterceptedDataFactory) (*in } // Create creates instances of InterceptedData by unmarshalling provided buffer -func (irtdf *interceptedRewardTxDataFactory) Create(buff []byte) (process.InterceptedData, error) { +func (irtdf *interceptedRewardTxDataFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { return rewardTransaction.NewInterceptedRewardTransaction( buff, irtdf.protoMarshalizer, diff --git a/process/interceptors/factory/interceptedRewardTxDataFactory_test.go b/process/interceptors/factory/interceptedRewardTxDataFactory_test.go index 86200f39e57..da7971d86b2 100644 --- 
a/process/interceptors/factory/interceptedRewardTxDataFactory_test.go +++ b/process/interceptors/factory/interceptedRewardTxDataFactory_test.go @@ -93,7 +93,7 @@ func TestInterceptedRewardTxDataFactory_ShouldWorkAndCreate(t *testing.T) { marshalizer := &mock.MarshalizerMock{} emptyRewardTx := &rewardTx.RewardTx{} emptyRewardTxBuff, _ := marshalizer.Marshal(emptyRewardTx) - interceptedData, err := imh.Create(emptyRewardTxBuff) + interceptedData, err := imh.Create(emptyRewardTxBuff, "") assert.Nil(t, err) _, ok := interceptedData.(*rewardTransaction.InterceptedRewardTransaction) diff --git a/process/interceptors/factory/interceptedShardHeaderDataFactory.go b/process/interceptors/factory/interceptedShardHeaderDataFactory.go index fd19194dbd0..a2a52db7594 100644 --- a/process/interceptors/factory/interceptedShardHeaderDataFactory.go +++ b/process/interceptors/factory/interceptedShardHeaderDataFactory.go @@ -1,9 +1,12 @@ package factory import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/sharding" @@ -12,13 +15,15 @@ import ( var _ process.InterceptedDataFactory = (*interceptedShardHeaderDataFactory)(nil) type interceptedShardHeaderDataFactory struct { - marshalizer marshal.Marshalizer - hasher hashing.Hasher - shardCoordinator sharding.Coordinator - headerSigVerifier process.InterceptedHeaderSigVerifier - headerIntegrityVerifier process.HeaderIntegrityVerifier - validityAttester process.ValidityAttester - epochStartTrigger process.EpochStartTriggerHandler + marshalizer marshal.Marshalizer + hasher hashing.Hasher + shardCoordinator sharding.Coordinator + headerSigVerifier process.InterceptedHeaderSigVerifier + headerIntegrityVerifier process.HeaderIntegrityVerifier + validityAttester process.ValidityAttester + epochStartTrigger process.EpochStartTriggerHandler + enableEpochsHandler common.EnableEpochsHandler + epochChangeGracePeriodHandler common.EpochChangeGracePeriodHandler } // NewInterceptedShardHeaderDataFactory creates an instance of interceptedShardHeaderDataFactory @@ -58,27 +63,31 @@ func NewInterceptedShardHeaderDataFactory(argument *ArgInterceptedDataFactory) ( } return &interceptedShardHeaderDataFactory{ - marshalizer: argument.CoreComponents.InternalMarshalizer(), - hasher: argument.CoreComponents.Hasher(), - shardCoordinator: argument.ShardCoordinator, - headerSigVerifier: argument.HeaderSigVerifier, - headerIntegrityVerifier: argument.HeaderIntegrityVerifier, - validityAttester: argument.ValidityAttester, - epochStartTrigger: argument.EpochStartTrigger, + marshalizer: argument.CoreComponents.InternalMarshalizer(), + hasher: argument.CoreComponents.Hasher(), + shardCoordinator: argument.ShardCoordinator, + headerSigVerifier: argument.HeaderSigVerifier, + headerIntegrityVerifier: argument.HeaderIntegrityVerifier, + validityAttester: argument.ValidityAttester, + epochStartTrigger: argument.EpochStartTrigger, + enableEpochsHandler: argument.CoreComponents.EnableEpochsHandler(), + epochChangeGracePeriodHandler: argument.CoreComponents.EpochChangeGracePeriodHandler(), }, nil } // Create creates instances of 
InterceptedData by unmarshalling provided buffer -func (ishdf *interceptedShardHeaderDataFactory) Create(buff []byte) (process.InterceptedData, error) { +func (ishdf *interceptedShardHeaderDataFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { arg := &interceptedBlocks.ArgInterceptedBlockHeader{ - HdrBuff: buff, - Marshalizer: ishdf.marshalizer, - Hasher: ishdf.hasher, - ShardCoordinator: ishdf.shardCoordinator, - HeaderSigVerifier: ishdf.headerSigVerifier, - HeaderIntegrityVerifier: ishdf.headerIntegrityVerifier, - ValidityAttester: ishdf.validityAttester, - EpochStartTrigger: ishdf.epochStartTrigger, + HdrBuff: buff, + Marshalizer: ishdf.marshalizer, + Hasher: ishdf.hasher, + ShardCoordinator: ishdf.shardCoordinator, + HeaderSigVerifier: ishdf.headerSigVerifier, + HeaderIntegrityVerifier: ishdf.headerIntegrityVerifier, + ValidityAttester: ishdf.validityAttester, + EpochStartTrigger: ishdf.epochStartTrigger, + EnableEpochsHandler: ishdf.enableEpochsHandler, + EpochChangeGracePeriodHandler: ishdf.epochChangeGracePeriodHandler, } return interceptedBlocks.NewInterceptedHeader(arg) diff --git a/process/interceptors/factory/interceptedShardHeaderDataFactory_test.go b/process/interceptors/factory/interceptedShardHeaderDataFactory_test.go index 327be59018c..31adf2802a1 100644 --- a/process/interceptors/factory/interceptedShardHeaderDataFactory_test.go +++ b/process/interceptors/factory/interceptedShardHeaderDataFactory_test.go @@ -106,7 +106,7 @@ func TestInterceptedShardHeaderDataFactory_ShouldWorkAndCreate(t *testing.T) { marshalizer := &mock.MarshalizerMock{} emptyBlockHeader := &block.Header{} emptyBlockHeaderBuff, _ := marshalizer.Marshal(emptyBlockHeader) - interceptedData, err := imh.Create(emptyBlockHeaderBuff) + interceptedData, err := imh.Create(emptyBlockHeaderBuff, "") assert.Nil(t, err) _, ok := interceptedData.(*interceptedBlocks.InterceptedHeader) diff --git a/process/interceptors/factory/interceptedTrieNodeDataFactory.go b/process/interceptors/factory/interceptedTrieNodeDataFactory.go index bd204d45c76..33e09286487 100644 --- a/process/interceptors/factory/interceptedTrieNodeDataFactory.go +++ b/process/interceptors/factory/interceptedTrieNodeDataFactory.go @@ -1,6 +1,7 @@ package factory import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-go/process" @@ -34,7 +35,7 @@ func NewInterceptedTrieNodeDataFactory( } // Create creates instances of InterceptedData by unmarshalling provided buffer -func (sidf *interceptedTrieNodeDataFactory) Create(buff []byte) (process.InterceptedData, error) { +func (sidf *interceptedTrieNodeDataFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { return trie.NewInterceptedTrieNode(buff, sidf.hasher) } diff --git a/process/interceptors/factory/interceptedTxDataFactory.go b/process/interceptors/factory/interceptedTxDataFactory.go index 0e1a568ad53..65fe2f69e7c 100644 --- a/process/interceptors/factory/interceptedTxDataFactory.go +++ b/process/interceptors/factory/interceptedTxDataFactory.go @@ -113,7 +113,7 @@ func NewInterceptedTxDataFactory(argument *ArgInterceptedDataFactory) (*intercep } // Create creates instances of InterceptedData by unmarshalling provided buffer -func (itdf *interceptedTxDataFactory) Create(buff []byte) (process.InterceptedData, error) { +func (itdf 
*interceptedTxDataFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { return transaction.NewInterceptedTransaction( buff, itdf.protoMarshalizer, diff --git a/process/interceptors/factory/interceptedTxDataFactory_test.go b/process/interceptors/factory/interceptedTxDataFactory_test.go index c4b8321728a..56efc31b681 100644 --- a/process/interceptors/factory/interceptedTxDataFactory_test.go +++ b/process/interceptors/factory/interceptedTxDataFactory_test.go @@ -196,7 +196,7 @@ func TestInterceptedTxDataFactory_ShouldWorkAndCreate(t *testing.T) { Value: big.NewInt(0), } emptyTxBuff, _ := marshalizer.Marshal(emptyTx) - interceptedData, err := imh.Create(emptyTxBuff) + interceptedData, err := imh.Create(emptyTxBuff, "") assert.Nil(t, err) _, ok := interceptedData.(*transaction.InterceptedTransaction) diff --git a/process/interceptors/factory/interceptedUnsignedTxDataFactory.go b/process/interceptors/factory/interceptedUnsignedTxDataFactory.go index 4887893b26d..44233ae1ef6 100644 --- a/process/interceptors/factory/interceptedUnsignedTxDataFactory.go +++ b/process/interceptors/factory/interceptedUnsignedTxDataFactory.go @@ -52,7 +52,7 @@ func NewInterceptedUnsignedTxDataFactory(argument *ArgInterceptedDataFactory) (* } // Create creates instances of InterceptedData by unmarshalling provided buffer -func (iutdf *interceptedUnsignedTxDataFactory) Create(buff []byte) (process.InterceptedData, error) { +func (iutdf *interceptedUnsignedTxDataFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { return unsigned.NewInterceptedUnsignedTransaction( buff, iutdf.protoMarshalizer, diff --git a/process/interceptors/factory/interceptedUnsignedTxDataFactory_test.go b/process/interceptors/factory/interceptedUnsignedTxDataFactory_test.go index 85da0ca5664..41ab63596b1 100644 --- a/process/interceptors/factory/interceptedUnsignedTxDataFactory_test.go +++ b/process/interceptors/factory/interceptedUnsignedTxDataFactory_test.go @@ -93,7 +93,7 @@ func TestInterceptedUnsignedTxDataFactory_ShouldWorkAndCreate(t *testing.T) { marshalizer := &mock.MarshalizerMock{} emptyTx := &smartContractResult.SmartContractResult{} emptyTxBuff, _ := marshalizer.Marshal(emptyTx) - interceptedData, err := imh.Create(emptyTxBuff) + interceptedData, err := imh.Create(emptyTxBuff, "") assert.Nil(t, err) _, ok := interceptedData.(*unsigned.InterceptedUnsignedTransaction) diff --git a/process/interceptors/factory/interceptedValidatorInfoDataFactory.go b/process/interceptors/factory/interceptedValidatorInfoDataFactory.go index b6135c2a6a0..0ae55db3767 100644 --- a/process/interceptors/factory/interceptedValidatorInfoDataFactory.go +++ b/process/interceptors/factory/interceptedValidatorInfoDataFactory.go @@ -1,6 +1,7 @@ package factory import ( + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" @@ -41,7 +42,7 @@ func checkInterceptedValidatorInfoDataFactoryArgs(args ArgInterceptedDataFactory } // Create creates instances of InterceptedData by unmarshalling provided buffer -func (ividf *interceptedValidatorInfoDataFactory) Create(buff []byte) (process.InterceptedData, error) { +func (ividf *interceptedValidatorInfoDataFactory) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { args := peer.ArgInterceptedValidatorInfo{ DataBuff: buff, Marshalizer: ividf.marshaller, 
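Note: every factory touched in the hunks above adopts the same two-argument Create contract, receiving the originator peer ID next to the raw buffer (most implementations currently discard it with `_`). The snippet below is a minimal sketch of that contract, under the assumption that process.InterceptedDataFactory carries exactly these two methods; the sketch package and the createFromMessage helper are illustrative names, not part of this change.

package sketch

import (
    "github.com/multiversx/mx-chain-core-go/core"

    "github.com/multiversx/mx-chain-go/process"
)

// interceptedDataFactory restates the contract the factories in this diff now
// implement; the real interface is process.InterceptedDataFactory and may
// declare additional methods.
type interceptedDataFactory interface {
    // Create unmarshals the buffer into intercepted data; the originator peer
    // ID is now threaded through, even by factories that ignore it.
    Create(buff []byte, originator core.PeerID) (process.InterceptedData, error)
    IsInterfaceNil() bool
}

// createFromMessage illustrates how an interceptor now invokes a factory: the
// peer that originated the p2p message is forwarded alongside the raw payload.
func createFromMessage(factory interceptedDataFactory, payload []byte, originator core.PeerID) (process.InterceptedData, error) {
    return factory.Create(payload, originator)
}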
diff --git a/process/interceptors/factory/interceptedValidatorInfoDataFactory_test.go b/process/interceptors/factory/interceptedValidatorInfoDataFactory_test.go index a46f327c4f3..a6f1d9772d6 100644 --- a/process/interceptors/factory/interceptedValidatorInfoDataFactory_test.go +++ b/process/interceptors/factory/interceptedValidatorInfoDataFactory_test.go @@ -79,7 +79,7 @@ func TestInterceptedValidatorInfoDataFactory_Create(t *testing.T) { ividf, _ := NewInterceptedValidatorInfoDataFactory(*createMockArgument(createMockComponentHolders())) require.False(t, check.IfNil(ividf)) - ivi, err := ividf.Create(nil) + ivi, err := ividf.Create(nil, "") assert.NotNil(t, err) assert.True(t, check.IfNil(ivi)) }) @@ -88,7 +88,7 @@ func TestInterceptedValidatorInfoDataFactory_Create(t *testing.T) { ividf, _ := NewInterceptedValidatorInfoDataFactory(*createMockArgument(createMockComponentHolders())) require.False(t, check.IfNil(ividf)) - ivi, err := ividf.Create(createMockValidatorInfoBuff()) + ivi, err := ividf.Create(createMockValidatorInfoBuff(), "") assert.Nil(t, err) assert.False(t, check.IfNil(ivi)) }) diff --git a/process/interceptors/interceptedDataVerifier.go b/process/interceptors/interceptedDataVerifier.go new file mode 100644 index 00000000000..7f230b24ae2 --- /dev/null +++ b/process/interceptors/interceptedDataVerifier.go @@ -0,0 +1,70 @@ +package interceptors + +import ( + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/core/sync" + + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/storage" +) + +type interceptedDataStatus int8 + +const ( + validInterceptedData interceptedDataStatus = iota + interceptedDataStatusBytesSize = 8 +) + +type interceptedDataVerifier struct { + km sync.KeyRWMutexHandler + cache storage.Cacher +} + +// NewInterceptedDataVerifier creates a new instance of intercepted data verifier +func NewInterceptedDataVerifier(cache storage.Cacher) (*interceptedDataVerifier, error) { + if check.IfNil(cache) { + return nil, process.ErrNilInterceptedDataCache + } + + return &interceptedDataVerifier{ + km: sync.NewKeyRWMutex(), + cache: cache, + }, nil +} + +// Verify will check if the intercepted data has been validated before and put in the time cache. +// It will retrieve the status in the cache if it exists, otherwise it will validate it and store the status of the +// validation in the cache. 
Note that the entries are stored for a set period of time +func (idv *interceptedDataVerifier) Verify(interceptedData process.InterceptedData) error { + if len(interceptedData.Hash()) == 0 { + return interceptedData.CheckValidity() + } + + hash := string(interceptedData.Hash()) + idv.km.Lock(hash) + defer idv.km.Unlock(hash) + + if val, ok := idv.cache.Get(interceptedData.Hash()); ok { + if val == validInterceptedData { + return nil + } + + return process.ErrInvalidInterceptedData + } + + err := interceptedData.CheckValidity() + if err != nil { + log.Debug("Intercepted data is invalid", "hash", interceptedData.Hash(), "err", err) + // TODO: investigate to selectively add as invalid intercepted data only when data is indeed invalid instead of missing + // idv.cache.Put(interceptedData.Hash(), invalidInterceptedData, interceptedDataStatusBytesSize) + return process.ErrInvalidInterceptedData + } + + idv.cache.Put(interceptedData.Hash(), validInterceptedData, interceptedDataStatusBytesSize) + return nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (idv *interceptedDataVerifier) IsInterfaceNil() bool { + return idv == nil +} diff --git a/process/interceptors/interceptedDataVerifier_test.go b/process/interceptors/interceptedDataVerifier_test.go new file mode 100644 index 00000000000..503eb790f32 --- /dev/null +++ b/process/interceptors/interceptedDataVerifier_test.go @@ -0,0 +1,236 @@ +package interceptors + +import ( + "sync" + "testing" + "time" + + "github.com/multiversx/mx-chain-core-go/core/atomic" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/storage" + "github.com/multiversx/mx-chain-go/storage/cache" + "github.com/multiversx/mx-chain-go/testscommon" +) + +const defaultSpan = 1 * time.Second + +func defaultInterceptedDataVerifier(span time.Duration) *interceptedDataVerifier { + c, _ := cache.NewTimeCacher(cache.ArgTimeCacher{ + DefaultSpan: span, + CacheExpiry: span, + }) + + verifier, _ := NewInterceptedDataVerifier(c) + return verifier +} + +func TestNewInterceptedDataVerifier(t *testing.T) { + t.Parallel() + + var c storage.Cacher + verifier, err := NewInterceptedDataVerifier(c) + require.Equal(t, process.ErrNilInterceptedDataCache, err) + require.Nil(t, verifier) +} + +func TestInterceptedDataVerifier_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var verifier *interceptedDataVerifier + require.True(t, verifier.IsInterfaceNil()) + + verifier = defaultInterceptedDataVerifier(defaultSpan) + require.False(t, verifier.IsInterfaceNil()) +} + +func TestInterceptedDataVerifier_EmptyHash(t *testing.T) { + t.Parallel() + + var checkValidityCounter int + verifier := defaultInterceptedDataVerifier(defaultSpan) + interceptedData := &testscommon.InterceptedDataStub{ + CheckValidityCalled: func() error { + checkValidityCounter++ + return nil + }, + IsForCurrentShardCalled: func() bool { + return false + }, + HashCalled: func() []byte { + return nil + }, + } + + err := verifier.Verify(interceptedData) + require.NoError(t, err) + require.Equal(t, 1, checkValidityCounter) + + err = verifier.Verify(interceptedData) + require.NoError(t, err) + require.Equal(t, 2, checkValidityCounter) +} + +func TestInterceptedDataVerifier_CheckValidityShouldWork(t *testing.T) { + t.Parallel() + + checkValidityCounter := atomic.Counter{} + + interceptedData := &testscommon.InterceptedDataStub{ + 
CheckValidityCalled: func() error { + checkValidityCounter.Add(1) + return nil + }, + IsForCurrentShardCalled: func() bool { + return false + }, + HashCalled: func() []byte { + return []byte("hash") + }, + } + + verifier := defaultInterceptedDataVerifier(defaultSpan) + + err := verifier.Verify(interceptedData) + require.NoError(t, err) + + errCount := atomic.Counter{} + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := verifier.Verify(interceptedData) + if err != nil { + errCount.Add(1) + } + }() + } + wg.Wait() + + require.Equal(t, int64(0), errCount.Get()) + require.Equal(t, int64(1), checkValidityCounter.Get()) +} + +func TestInterceptedDataVerifier_CheckValidityShouldNotWork(t *testing.T) { + t.Parallel() + + checkValidityCounter := atomic.Counter{} + interceptedData := &testscommon.InterceptedDataStub{ + CheckValidityCalled: func() error { + checkValidityCounter.Add(1) + return process.ErrInvalidInterceptedData + }, + IsForCurrentShardCalled: func() bool { + return false + }, + HashCalled: func() []byte { + return []byte("hash") + }, + } + + verifier := defaultInterceptedDataVerifier(defaultSpan) + + err := verifier.Verify(interceptedData) + require.Equal(t, process.ErrInvalidInterceptedData, err) + + errCount := atomic.Counter{} + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := verifier.Verify(interceptedData) + if err != nil { + errCount.Add(1) + } + }() + } + wg.Wait() + + require.Equal(t, int64(100), errCount.Get()) + require.Equal(t, int64(101), checkValidityCounter.Get()) +} + +func TestInterceptedDataVerifier_CheckExpiryTime(t *testing.T) { + t.Parallel() + + t.Run("expiry on valid data", func(t *testing.T) { + expiryTestDuration := defaultSpan * 2 + + checkValidityCounter := atomic.Counter{} + + interceptedData := &testscommon.InterceptedDataStub{ + CheckValidityCalled: func() error { + checkValidityCounter.Add(1) + return nil + }, + IsForCurrentShardCalled: func() bool { + return false + }, + HashCalled: func() []byte { + return []byte("hash") + }, + } + + verifier := defaultInterceptedDataVerifier(expiryTestDuration) + + // First retrieval, check validity is reached. + err := verifier.Verify(interceptedData) + require.NoError(t, err) + require.Equal(t, int64(1), checkValidityCounter.Get()) + + // Second retrieval should be from the cache. + err = verifier.Verify(interceptedData) + require.NoError(t, err) + require.Equal(t, int64(1), checkValidityCounter.Get()) + + // Wait for the cache expiry + <-time.After(expiryTestDuration + 100*time.Millisecond) + + // Third retrieval should reach validity check again. + err = verifier.Verify(interceptedData) + require.NoError(t, err) + require.Equal(t, int64(2), checkValidityCounter.Get()) + }) + + t.Run("expiry on invalid data", func(t *testing.T) { + expiryTestDuration := defaultSpan * 2 + + checkValidityCounter := atomic.Counter{} + + interceptedData := &testscommon.InterceptedDataStub{ + CheckValidityCalled: func() error { + checkValidityCounter.Add(1) + return process.ErrInvalidInterceptedData + }, + IsForCurrentShardCalled: func() bool { + return false + }, + HashCalled: func() []byte { + return []byte("hash") + }, + } + + verifier := defaultInterceptedDataVerifier(expiryTestDuration) + + // First retrieval, check validity is reached. 
+ err := verifier.Verify(interceptedData) + require.Equal(t, process.ErrInvalidInterceptedData, err) + require.Equal(t, int64(1), checkValidityCounter.Get()) + + // Second retrieval + err = verifier.Verify(interceptedData) + require.Equal(t, process.ErrInvalidInterceptedData, err) + require.Equal(t, int64(2), checkValidityCounter.Get()) + + // Wait for the cache expiry + <-time.After(expiryTestDuration + 100*time.Millisecond) + + // Third retrieval should reach validity check again. + err = verifier.Verify(interceptedData) + require.Equal(t, process.ErrInvalidInterceptedData, err) + require.Equal(t, int64(3), checkValidityCounter.Get()) + }) +} diff --git a/process/interceptors/multiDataInterceptor.go b/process/interceptors/multiDataInterceptor.go index 9e0197ea741..76b33046b03 100644 --- a/process/interceptors/multiDataInterceptor.go +++ b/process/interceptors/multiDataInterceptor.go @@ -6,34 +6,40 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" + "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/pkg/errors" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/debug/handler" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors/disabled" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("process/interceptors") // ArgMultiDataInterceptor is the argument for the multi-data interceptor type ArgMultiDataInterceptor struct { - Topic string - Marshalizer marshal.Marshalizer - DataFactory process.InterceptedDataFactory - Processor process.InterceptorProcessor - Throttler process.InterceptorThrottler - AntifloodHandler process.P2PAntifloodHandler - WhiteListRequest process.WhiteListHandler - PreferredPeersHolder process.PreferredPeersHolderHandler - CurrentPeerId core.PeerID + Topic string + Marshalizer marshal.Marshalizer + Hasher hashing.Hasher + DataFactory process.InterceptedDataFactory + Processor process.InterceptorProcessor + Throttler process.InterceptorThrottler + AntifloodHandler process.P2PAntifloodHandler + WhiteListRequest process.WhiteListHandler + PreferredPeersHolder process.PreferredPeersHolderHandler + CurrentPeerId core.PeerID + InterceptedDataVerifier process.InterceptedDataVerifier } // MultiDataInterceptor is used for intercepting packed multi data type MultiDataInterceptor struct { *baseDataInterceptor marshalizer marshal.Marshalizer + hasher hashing.Hasher factory process.InterceptedDataFactory whiteListRequest process.WhiteListHandler mutChunksProcessor sync.RWMutex @@ -48,6 +54,9 @@ func NewMultiDataInterceptor(arg ArgMultiDataInterceptor) (*MultiDataInterceptor if check.IfNil(arg.Marshalizer) { return nil, process.ErrNilMarshalizer } + if check.IfNil(arg.Hasher) { + return nil, process.ErrNilHasher + } if check.IfNil(arg.DataFactory) { return nil, process.ErrNilInterceptedDataFactory } @@ -66,21 +75,26 @@ func NewMultiDataInterceptor(arg ArgMultiDataInterceptor) (*MultiDataInterceptor if check.IfNil(arg.PreferredPeersHolder) { return nil, process.ErrNilPreferredPeersHolder } + if 
check.IfNil(arg.InterceptedDataVerifier) { + return nil, process.ErrNilInterceptedDataVerifier + } if len(arg.CurrentPeerId) == 0 { return nil, process.ErrEmptyPeerID } multiDataIntercept := &MultiDataInterceptor{ baseDataInterceptor: &baseDataInterceptor{ - throttler: arg.Throttler, - antifloodHandler: arg.AntifloodHandler, - topic: arg.Topic, - currentPeerId: arg.CurrentPeerId, - processor: arg.Processor, - preferredPeersHolder: arg.PreferredPeersHolder, - debugHandler: handler.NewDisabledInterceptorDebugHandler(), + throttler: arg.Throttler, + antifloodHandler: arg.AntifloodHandler, + topic: arg.Topic, + currentPeerId: arg.CurrentPeerId, + processor: arg.Processor, + preferredPeersHolder: arg.PreferredPeersHolder, + debugHandler: handler.NewDisabledInterceptorDebugHandler(), + interceptedDataVerifier: arg.InterceptedDataVerifier, }, marshalizer: arg.Marshalizer, + hasher: arg.Hasher, factory: arg.DataFactory, whiteListRequest: arg.WhiteListRequest, chunksProcessor: disabled.NewDisabledInterceptedChunksProcessor(), @@ -91,10 +105,10 @@ func NewMultiDataInterceptor(arg ArgMultiDataInterceptor) (*MultiDataInterceptor // ProcessReceivedMessage is the callback func from the p2p.Messenger and will be called each time a new message was received // (for the topic this validator was registered to) -func (mdi *MultiDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) error { +func (mdi *MultiDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) ([]byte, error) { err := mdi.preProcessMesage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } b := batch.Batch{} @@ -107,13 +121,13 @@ func (mdi *MultiDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, mdi.antifloodHandler.BlacklistPeer(message.Peer(), reason, common.InvalidMessageBlacklistDuration) mdi.antifloodHandler.BlacklistPeer(fromConnectedPeer, reason, common.InvalidMessageBlacklistDuration) - return err + return nil, err } multiDataBuff := b.Data lenMultiData := len(multiDataBuff) if lenMultiData == 0 { mdi.throttler.EndProcessing() - return process.ErrNoDataInMessage + return nil, process.ErrNoDataInMessage } err = mdi.antifloodHandler.CanProcessMessagesOnTopic( @@ -125,7 +139,7 @@ func (mdi *MultiDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, ) if err != nil { mdi.throttler.EndProcessing() - return err + return nil, err } mdi.mutChunksProcessor.RLock() @@ -133,13 +147,13 @@ func (mdi *MultiDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, mdi.mutChunksProcessor.RUnlock() if err != nil { mdi.throttler.EndProcessing() - return err + return nil, err } isIncompleteChunk := checkChunksRes.IsChunk && !checkChunksRes.HaveAllChunks if isIncompleteChunk { mdi.throttler.EndProcessing() - return nil + return nil, nil } isCompleteChunk := checkChunksRes.IsChunk && checkChunksRes.HaveAllChunks if isCompleteChunk { @@ -153,9 +167,10 @@ func (mdi *MultiDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, var interceptedData process.InterceptedData interceptedData, err = mdi.interceptedData(dataBuff, message.Peer(), fromConnectedPeer) listInterceptedData[index] = interceptedData + if err != nil { mdi.throttler.EndProcessing() - return err + return nil, err } isWhiteListed := mdi.whiteListRequest.IsWhiteListed(interceptedData) @@ -165,7 +180,7 @@ func (mdi *MultiDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, 
p2p.PeerIdToShortString(message.Peer()), "topic", mdi.topic, "err", errOriginator) - return errOriginator + return nil, errOriginator } isForCurrentShard := interceptedData.IsForCurrentShard() @@ -180,7 +195,7 @@ func (mdi *MultiDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, "is white listed", isWhiteListed, ) mdi.throttler.EndProcessing() - return process.ErrInterceptedDataNotForCurrentShard + return nil, process.ErrInterceptedDataNotForCurrentShard } } @@ -191,11 +206,33 @@ func (mdi *MultiDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, mdi.throttler.EndProcessing() }() - return nil + messageID := mdi.createInterceptedMultiDataMsgID(listInterceptedData) + + return messageID, nil +} + +func (mdi *MultiDataInterceptor) createInterceptedMultiDataMsgID(interceptedMultiData []process.InterceptedData) []byte { + if len(interceptedMultiData) == 0 { + return nil + } + if len(interceptedMultiData) == 1 { + return interceptedMultiData[0].Hash() + } + + lenOneID := len(interceptedMultiData[0].Hash()) + data := make([]byte, 0, lenOneID*len(interceptedMultiData)) + for _, id := range interceptedMultiData { + data = append(data, id.Hash()...) + } + if len(data) == 0 { + return []byte{} + } + + return mdi.hasher.Compute(string(data)) } func (mdi *MultiDataInterceptor) interceptedData(dataBuff []byte, originator core.PeerID, fromConnectedPeer core.PeerID) (process.InterceptedData, error) { - interceptedData, err := mdi.factory.Create(dataBuff) + interceptedData, err := mdi.factory.Create(dataBuff, originator) if err != nil { // this situation is so severe that we need to black list de peers reason := "can not create object from received bytes, topic " + mdi.topic + ", error " + err.Error() @@ -207,11 +244,11 @@ func (mdi *MultiDataInterceptor) interceptedData(dataBuff []byte, originator cor mdi.receivedDebugInterceptedData(interceptedData) - err = interceptedData.CheckValidity() + err = mdi.interceptedDataVerifier.Verify(interceptedData) if err != nil { mdi.processDebugInterceptedData(interceptedData, err) - isWrongVersion := err == process.ErrInvalidTransactionVersion || err == process.ErrInvalidChainID + isWrongVersion := errors.Is(err, process.ErrInvalidTransactionVersion) || errors.Is(err, process.ErrInvalidChainID) if isWrongVersion { // this situation is so severe that we need to black list de peers reason := "wrong version of received intercepted data, topic " + mdi.topic + ", error " + err.Error() diff --git a/process/interceptors/multiDataInterceptor_test.go b/process/interceptors/multiDataInterceptor_test.go index 6ca244409b7..3f0e303af1c 100644 --- a/process/interceptors/multiDataInterceptor_test.go +++ b/process/interceptors/multiDataInterceptor_test.go @@ -10,28 +10,35 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) var 
fromConnectedPeerId = core.PeerID("from connected peer Id") func createMockArgMultiDataInterceptor() interceptors.ArgMultiDataInterceptor { return interceptors.ArgMultiDataInterceptor{ - Topic: "test topic", - Marshalizer: &mock.MarshalizerMock{}, - DataFactory: &mock.InterceptedDataFactoryStub{}, - Processor: &mock.InterceptorProcessorStub{}, - Throttler: createMockThrottler(), - AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, - WhiteListRequest: &testscommon.WhiteListHandlerStub{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - CurrentPeerId: "pid", + Topic: "test topic", + Marshalizer: &mock.MarshalizerMock{}, + Hasher: &mock.HasherStub{ + ComputeCalled: func(s string) []byte { + return []byte("hash") + }, + }, + DataFactory: &mock.InterceptedDataFactoryStub{}, + Processor: &mock.InterceptorProcessorStub{}, + Throttler: createMockThrottler(), + AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, + WhiteListRequest: &testscommon.WhiteListHandlerStub{}, + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + CurrentPeerId: "pid", + InterceptedDataVerifier: &mock.InterceptedDataVerifierMock{}, } } @@ -68,6 +75,17 @@ func TestNewMultiDataInterceptor_NilInterceptedDataFactoryShouldErr(t *testing.T assert.Equal(t, process.ErrNilInterceptedDataFactory, err) } +func TestNewMultiDataInterceptor_NilInterceptedDataVerifierShouldErr(t *testing.T) { + t.Parallel() + + arg := createMockArgMultiDataInterceptor() + arg.InterceptedDataVerifier = nil + mdi, err := interceptors.NewMultiDataInterceptor(arg) + + assert.True(t, check.IfNil(mdi)) + assert.Equal(t, process.ErrNilInterceptedDataVerifier, err) +} + func TestNewMultiDataInterceptor_NilInterceptedDataProcessorShouldErr(t *testing.T) { t.Parallel() @@ -145,7 +163,7 @@ func TestNewMultiDataInterceptor(t *testing.T) { assert.Equal(t, arg.Topic, mdi.Topic()) } -//------- ProcessReceivedMessage +// ------- ProcessReceivedMessage func TestMultiDataInterceptor_ProcessReceivedMessageNilMessageShouldErr(t *testing.T) { t.Parallel() @@ -153,9 +171,10 @@ func TestMultiDataInterceptor_ProcessReceivedMessageNilMessageShouldErr(t *testi arg := createMockArgMultiDataInterceptor() mdi, _ := interceptors.NewMultiDataInterceptor(arg) - err := mdi.ProcessReceivedMessage(nil, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(nil, fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, process.ErrNilMessage, err) + assert.Nil(t, msgID) } func TestMultiDataInterceptor_ProcessReceivedMessageUnmarshalFailsShouldErr(t *testing.T) { @@ -188,11 +207,12 @@ func TestMultiDataInterceptor_ProcessReceivedMessageUnmarshalFailsShouldErr(t *t DataField: []byte("data to be processed"), PeerField: originatorPid, } - err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, errExpeced, err) assert.True(t, originatorBlackListed) assert.True(t, fromConnectedPeerBlackListed) + assert.Nil(t, msgID) } func TestMultiDataInterceptor_ProcessReceivedMessageUnmarshalReturnsEmptySliceShouldErr(t *testing.T) { @@ -209,9 +229,10 @@ func TestMultiDataInterceptor_ProcessReceivedMessageUnmarshalReturnsEmptySliceSh msg := &p2pmocks.P2PMessageMock{ DataField: []byte("data to be processed"), } - err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, 
process.ErrNoDataInMessage, err) + assert.Nil(t, msgID) } func TestMultiDataInterceptor_ProcessReceivedCreateFailsShouldErr(t *testing.T) { @@ -251,7 +272,7 @@ func TestMultiDataInterceptor_ProcessReceivedCreateFailsShouldErr(t *testing.T) DataField: dataField, PeerField: originatorPid, } - err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) @@ -262,6 +283,7 @@ func TestMultiDataInterceptor_ProcessReceivedCreateFailsShouldErr(t *testing.T) assert.Equal(t, int32(1), throttler.EndProcessingCount()) assert.True(t, originatorBlackListed) assert.True(t, fromConnectedPeerBlackListed) + assert.Nil(t, msgID) } func TestMultiDataInterceptor_ProcessReceivedPartiallyCorrectDataShouldErr(t *testing.T) { @@ -282,6 +304,7 @@ func TestMultiDataInterceptor_ProcessReceivedPartiallyCorrectDataShouldErr(t *te IsForCurrentShardCalled: func() bool { return true }, + HashCalled: func() []byte { return []byte("hash") }, } arg := createMockArgMultiDataInterceptor() arg.DataFactory = &mock.InterceptedDataFactoryStub{ @@ -301,7 +324,7 @@ func TestMultiDataInterceptor_ProcessReceivedPartiallyCorrectDataShouldErr(t *te msg := &p2pmocks.P2PMessageMock{ DataField: dataField, } - err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) @@ -310,6 +333,7 @@ func TestMultiDataInterceptor_ProcessReceivedPartiallyCorrectDataShouldErr(t *te assert.Equal(t, int32(0), atomic.LoadInt32(&processCalledNum)) assert.Equal(t, int32(1), throttler.StartProcessingCount()) assert.Equal(t, int32(1), throttler.EndProcessingCount()) + assert.Nil(t, msgID) } func TestMultiDataInterceptor_ProcessReceivedMessageNotValidShouldErrAndNotProcess(t *testing.T) { @@ -354,13 +378,18 @@ func testProcessReceiveMessageMultiData(t *testing.T, isForCurrentShard bool, ex } arg.Processor = createMockInterceptorStub(&checkCalledNum, &processCalledNum) arg.Throttler = throttler + arg.InterceptedDataVerifier = &mock.InterceptedDataVerifierMock{ + VerifyCalled: func(interceptedData process.InterceptedData) error { + return interceptedData.CheckValidity() + }, + } mdi, _ := interceptors.NewMultiDataInterceptor(arg) dataField, _ := marshalizer.Marshal(&batch.Batch{Data: buffData}) msg := &p2pmocks.P2PMessageMock{ DataField: dataField, } - err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) @@ -369,6 +398,7 @@ func testProcessReceiveMessageMultiData(t *testing.T, isForCurrentShard bool, ex assert.Equal(t, int32(calledNum), atomic.LoadInt32(&processCalledNum)) assert.Equal(t, int32(1), throttler.StartProcessingCount()) assert.Equal(t, int32(1), throttler.EndProcessingCount()) + assert.Len(t, msgID, 0) } func TestMultiDataInterceptor_ProcessReceivedMessageCheckBatchErrors(t *testing.T) { @@ -401,13 +431,14 @@ func TestMultiDataInterceptor_ProcessReceivedMessageCheckBatchErrors(t *testing. 
msg := &p2pmocks.P2PMessageMock{ DataField: dataField, } - err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, expectedErr, err) assert.Equal(t, int32(1), throttler.StartProcessingCount()) assert.Equal(t, int32(1), throttler.EndProcessingCount()) + assert.Nil(t, msgID) } func TestMultiDataInterceptor_ProcessReceivedMessageCheckBatchIsIncomplete(t *testing.T) { @@ -443,13 +474,14 @@ func TestMultiDataInterceptor_ProcessReceivedMessageCheckBatchIsIncomplete(t *te msg := &p2pmocks.P2PMessageMock{ DataField: dataField, } - err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Nil(t, err) assert.Equal(t, int32(1), throttler.StartProcessingCount()) assert.Equal(t, int32(1), throttler.EndProcessingCount()) + assert.Nil(t, msgID) } func TestMultiDataInterceptor_ProcessReceivedMessageCheckBatchIsComplete(t *testing.T) { @@ -462,6 +494,7 @@ func TestMultiDataInterceptor_ProcessReceivedMessageCheckBatchIsComplete(t *test processCalledNum := int32(0) throttler := createMockThrottler() arg := createMockArgMultiDataInterceptor() + msgHash := []byte("hash") interceptedData := &testscommon.InterceptedDataStub{ CheckValidityCalled: func() error { return nil @@ -469,10 +502,11 @@ func TestMultiDataInterceptor_ProcessReceivedMessageCheckBatchIsComplete(t *test IsForCurrentShardCalled: func() bool { return true }, + HashCalled: func() []byte { return msgHash }, } arg.DataFactory = &mock.InterceptedDataFactoryStub{ CreateCalled: func(buff []byte) (data process.InterceptedData, e error) { - assert.Equal(t, newBuffData, buff) //chunk processor switched the buffer + assert.Equal(t, newBuffData, buff) // chunk processor switched the buffer createCalled = true return interceptedData, nil }, @@ -496,7 +530,7 @@ func TestMultiDataInterceptor_ProcessReceivedMessageCheckBatchIsComplete(t *test msg := &p2pmocks.P2PMessageMock{ DataField: dataField, } - err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) @@ -504,6 +538,7 @@ func TestMultiDataInterceptor_ProcessReceivedMessageCheckBatchIsComplete(t *test assert.True(t, createCalled) assert.Equal(t, int32(1), throttler.StartProcessingCount()) assert.Equal(t, int32(1), throttler.EndProcessingCount()) + assert.Equal(t, msgHash, msgID) } func TestMultiDataInterceptor_ProcessReceivedMessageWhitelistedShouldRetNil(t *testing.T) { @@ -514,6 +549,7 @@ func TestMultiDataInterceptor_ProcessReceivedMessageWhitelistedShouldRetNil(t *t checkCalledNum := int32(0) processCalledNum := int32(0) throttler := createMockThrottler() + msgHash := []byte("hash") interceptedData := &testscommon.InterceptedDataStub{ CheckValidityCalled: func() error { return nil @@ -521,6 +557,9 @@ func TestMultiDataInterceptor_ProcessReceivedMessageWhitelistedShouldRetNil(t *t IsForCurrentShardCalled: func() bool { return false }, + HashCalled: func() []byte { + return msgHash + }, } arg := createMockArgMultiDataInterceptor() arg.DataFactory = &mock.InterceptedDataFactoryStub{ @@ -541,7 +580,7 @@ func TestMultiDataInterceptor_ProcessReceivedMessageWhitelistedShouldRetNil(t *t msg := &p2pmocks.P2PMessageMock{ DataField: dataField, } - 
err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) @@ -550,6 +589,7 @@ func TestMultiDataInterceptor_ProcessReceivedMessageWhitelistedShouldRetNil(t *t assert.Equal(t, int32(2), atomic.LoadInt32(&processCalledNum)) assert.Equal(t, int32(1), throttler.StartProcessingCount()) assert.Equal(t, int32(1), throttler.EndProcessingCount()) + assert.Equal(t, msgHash, msgID) } func TestMultiDataInterceptor_InvalidTxVersionShouldBackList(t *testing.T) { @@ -569,7 +609,11 @@ func processReceivedMessageMultiDataInvalidVersion(t *testing.T, expectedErr err marshalizer := &mock.MarshalizerMock{} checkCalledNum := int32(0) processCalledNum := int32(0) + msgHash := []byte("hash") interceptedData := &testscommon.InterceptedDataStub{ + HashCalled: func() []byte { + return msgHash + }, CheckValidityCalled: func() error { return expectedErr }, @@ -603,6 +647,11 @@ func processReceivedMessageMultiDataInvalidVersion(t *testing.T, expectedErr err return true }, } + arg.InterceptedDataVerifier = &mock.InterceptedDataVerifierMock{ + VerifyCalled: func(interceptedData process.InterceptedData) error { + return interceptedData.CheckValidity() + }, + } mdi, _ := interceptors.NewMultiDataInterceptor(arg) dataField, _ := marshalizer.Marshal(&batch.Batch{Data: buffData}) @@ -611,13 +660,14 @@ func processReceivedMessageMultiDataInvalidVersion(t *testing.T, expectedErr err PeerField: originator, } - err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, expectedErr, err) assert.True(t, isFromConnectedPeerBlackListed) assert.True(t, isOriginatorBlackListed) + assert.Nil(t, msgID) } -//------- debug +// ------- debug func TestMultiDataInterceptor_SetInterceptedDebugHandlerNilShouldErr(t *testing.T) { t.Parallel() @@ -640,7 +690,7 @@ func TestMultiDataInterceptor_SetInterceptedDebugHandlerShouldWork(t *testing.T) err := mdi.SetInterceptedDebugHandler(debugger) assert.Nil(t, err) - assert.True(t, debugger == mdi.InterceptedDebugHandler()) //pointer testing + assert.True(t, debugger == mdi.InterceptedDebugHandler()) // pointer testing } func TestMultiDataInterceptor_ProcessReceivedMessageIsOriginatorNotOkButWhiteListed(t *testing.T) { @@ -651,6 +701,7 @@ func TestMultiDataInterceptor_ProcessReceivedMessageIsOriginatorNotOkButWhiteLis checkCalledNum := int32(0) processCalledNum := int32(0) throttler := createMockThrottler() + msgHash := []byte("hash") interceptedData := &testscommon.InterceptedDataStub{ CheckValidityCalled: func() error { return nil @@ -658,6 +709,9 @@ func TestMultiDataInterceptor_ProcessReceivedMessageIsOriginatorNotOkButWhiteLis IsForCurrentShardCalled: func() bool { return false }, + HashCalled: func() []byte { + return msgHash + }, } whiteListHandler := &testscommon.WhiteListHandlerStub{ @@ -686,7 +740,7 @@ func TestMultiDataInterceptor_ProcessReceivedMessageIsOriginatorNotOkButWhiteLis msg := &p2pmocks.P2PMessageMock{ DataField: dataField, } - err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) @@ -695,11 +749,12 @@ func TestMultiDataInterceptor_ProcessReceivedMessageIsOriginatorNotOkButWhiteLis assert.Equal(t, int32(2), atomic.LoadInt32(&processCalledNum)) 
assert.Equal(t, int32(1), throttler.StartProcessingCount()) assert.Equal(t, int32(1), throttler.EndProcessingCount()) + assert.Equal(t, msgHash, msgID) whiteListHandler.IsWhiteListedCalled = func(interceptedData process.InterceptedData) bool { return false } - err = mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err = mdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) assert.Equal(t, err, errOriginator) @@ -707,6 +762,7 @@ func TestMultiDataInterceptor_ProcessReceivedMessageIsOriginatorNotOkButWhiteLis assert.Equal(t, int32(2), atomic.LoadInt32(&processCalledNum)) assert.Equal(t, int32(2), throttler.StartProcessingCount()) assert.Equal(t, int32(2), throttler.EndProcessingCount()) + assert.Nil(t, msgID) } func TestMultiDataInterceptor_RegisterHandler(t *testing.T) { diff --git a/process/interceptors/processor/argHdrInterceptorProcessor.go b/process/interceptors/processor/argHdrInterceptorProcessor.go index 53e79b731b7..0f9616fb2cf 100644 --- a/process/interceptors/processor/argHdrInterceptorProcessor.go +++ b/process/interceptors/processor/argHdrInterceptorProcessor.go @@ -1,12 +1,15 @@ package processor import ( + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" ) // ArgHdrInterceptorProcessor is the argument for the interceptor processor used for headers (shard, meta and so on) type ArgHdrInterceptorProcessor struct { - Headers dataRetriever.HeadersPool - BlockBlackList process.TimeCacher + Headers dataRetriever.HeadersPool + Proofs dataRetriever.ProofsPool + BlockBlackList process.TimeCacher + EnableEpochsHandler common.EnableEpochsHandler } diff --git a/process/interceptors/processor/equivalentProofsInterceptorProcessor.go b/process/interceptors/processor/equivalentProofsInterceptorProcessor.go new file mode 100644 index 00000000000..3b9e6997e27 --- /dev/null +++ b/process/interceptors/processor/equivalentProofsInterceptorProcessor.go @@ -0,0 +1,37 @@ +package processor + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/process" +) + +// equivalentProofsInterceptorProcessor is the processor used when intercepting equivalent proofs +type equivalentProofsInterceptorProcessor struct { +} + +// NewEquivalentProofsInterceptorProcessor creates a new equivalentProofsInterceptorProcessor +func NewEquivalentProofsInterceptorProcessor() *equivalentProofsInterceptorProcessor { + return &equivalentProofsInterceptorProcessor{} +} + +// Validate checks if the intercepted data can be processed +// returns nil as proper validity checks are done at intercepted data level +func (epip *equivalentProofsInterceptorProcessor) Validate(_ process.InterceptedData, _ core.PeerID) error { + return nil +} + +// Save returns nil +// proof is added after validity checks, at intercepted data level +func (epip *equivalentProofsInterceptorProcessor) Save(_ process.InterceptedData, _ core.PeerID, _ string) error { + return nil +} + +// RegisterHandler registers a callback function to be notified of incoming equivalent proofs +func (epip *equivalentProofsInterceptorProcessor) RegisterHandler(_ func(topic string, hash []byte, data interface{})) { + log.Error("equivalentProofsInterceptorProcessor.RegisterHandler", "error", "not implemented") +} + +// IsInterfaceNil returns true if there is no value under the 
interface +func (epip *equivalentProofsInterceptorProcessor) IsInterfaceNil() bool { + return epip == nil +} diff --git a/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go b/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go new file mode 100644 index 00000000000..943134404ad --- /dev/null +++ b/process/interceptors/processor/equivalentProofsInterceptorProcessor_test.go @@ -0,0 +1,34 @@ +package processor + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEquivalentProofsInterceptorProcessor_IsInterfaceNil(t *testing.T) { + t.Parallel() + + var epip *equivalentProofsInterceptorProcessor + require.True(t, epip.IsInterfaceNil()) + + epip = NewEquivalentProofsInterceptorProcessor() + require.False(t, epip.IsInterfaceNil()) +} + +func TestNewEquivalentProofsInterceptorProcessor(t *testing.T) { + t.Parallel() + + epip := NewEquivalentProofsInterceptorProcessor() + require.NotNil(t, epip) + + // coverage only + require.Nil(t, epip.Validate(nil, "")) + + // coverage only + err := epip.Save(nil, "", "") + require.Nil(t, err) + + // coverage only + epip.RegisterHandler(nil) +} diff --git a/process/interceptors/processor/hdrInterceptorProcessor.go b/process/interceptors/processor/hdrInterceptorProcessor.go index b71d5b73e59..524153a136a 100644 --- a/process/interceptors/processor/hdrInterceptorProcessor.go +++ b/process/interceptors/processor/hdrInterceptorProcessor.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" ) @@ -15,10 +17,12 @@ var _ process.InterceptorProcessor = (*HdrInterceptorProcessor)(nil) // HdrInterceptorProcessor is the processor used when intercepting headers // (shard headers, meta headers) structs which satisfy HeaderHandler interface. 
type HdrInterceptorProcessor struct { - headers dataRetriever.HeadersPool - blackList process.TimeCacher - registeredHandlers []func(topic string, hash []byte, data interface{}) - mutHandlers sync.RWMutex + headers dataRetriever.HeadersPool + proofs dataRetriever.ProofsPool + blackList process.TimeCacher + enableEpochsHandler common.EnableEpochsHandler + registeredHandlers []func(topic string, hash []byte, data interface{}) + mutHandlers sync.RWMutex } // NewHdrInterceptorProcessor creates a new TxInterceptorProcessor instance @@ -29,14 +33,22 @@ func NewHdrInterceptorProcessor(argument *ArgHdrInterceptorProcessor) (*HdrInter if check.IfNil(argument.Headers) { return nil, process.ErrNilCacher } + if check.IfNil(argument.Proofs) { + return nil, process.ErrNilProofsPool + } if check.IfNil(argument.BlockBlackList) { return nil, process.ErrNilBlackListCacher } + if check.IfNil(argument.EnableEpochsHandler) { + return nil, process.ErrNilEnableEpochsHandler + } return &HdrInterceptorProcessor{ - headers: argument.Headers, - blackList: argument.BlockBlackList, - registeredHandlers: make([]func(topic string, hash []byte, data interface{}), 0), + headers: argument.Headers, + proofs: argument.Proofs, + blackList: argument.BlockBlackList, + enableEpochsHandler: argument.EnableEpochsHandler, + registeredHandlers: make([]func(topic string, hash []byte, data interface{}), 0), }, nil } diff --git a/process/interceptors/processor/hdrInterceptorProcessor_test.go b/process/interceptors/processor/hdrInterceptorProcessor_test.go index 87fe3521ff7..2e5dd17584a 100644 --- a/process/interceptors/processor/hdrInterceptorProcessor_test.go +++ b/process/interceptors/processor/hdrInterceptorProcessor_test.go @@ -4,19 +4,26 @@ import ( "testing" "time" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/stretchr/testify/assert" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" - "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" ) func createMockHdrArgument() *processor.ArgHdrInterceptorProcessor { arg := &processor.ArgHdrInterceptorProcessor{ - Headers: &mock.HeadersCacherStub{}, - BlockBlackList: &testscommon.TimeCacheStub{}, + Headers: &mock.HeadersCacherStub{}, + Proofs: &dataRetriever.ProofsPoolMock{}, + BlockBlackList: &testscommon.TimeCacheStub{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } return arg @@ -55,6 +62,28 @@ func TestNewHdrInterceptorProcessor_NilBlackListHandlerShouldErr(t *testing.T) { assert.Equal(t, process.ErrNilBlackListCacher, err) } +func TestNewHdrInterceptorProcessor_NilProofsPoolShouldErr(t *testing.T) { + t.Parallel() + + arg := createMockHdrArgument() + arg.Proofs = nil + hip, err := processor.NewHdrInterceptorProcessor(arg) + + assert.Nil(t, hip) + assert.Equal(t, process.ErrNilProofsPool, err) +} + +func TestNewHdrInterceptorProcessor_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + arg := 
createMockHdrArgument() + arg.EnableEpochsHandler = nil + hip, err := processor.NewHdrInterceptorProcessor(arg) + + assert.Nil(t, hip) + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) +} + func TestNewHdrInterceptorProcessor_ShouldWork(t *testing.T) { t.Parallel() @@ -141,6 +170,7 @@ func TestHdrInterceptorProcessor_SaveNilDataShouldErr(t *testing.T) { func TestHdrInterceptorProcessor_SaveShouldWork(t *testing.T) { t.Parallel() + minNonceWithProof := uint64(2) hdrInterceptedData := &struct { testscommon.InterceptedDataStub mock.GetHdrHandlerStub @@ -152,7 +182,11 @@ func TestHdrInterceptorProcessor_SaveShouldWork(t *testing.T) { }, GetHdrHandlerStub: mock.GetHdrHandlerStub{ HeaderHandlerCalled: func() data.HeaderHandler { - return &testscommon.HeaderHandlerStub{} + return &testscommon.HeaderHandlerStub{ + GetNonceCalled: func() uint64 { + return minNonceWithProof + }, + } }, }, } @@ -165,6 +199,11 @@ func TestHdrInterceptorProcessor_SaveShouldWork(t *testing.T) { wasAddedHeaders = true }, } + arg.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } hip, _ := processor.NewHdrInterceptorProcessor(arg) chanCalled := make(chan struct{}, 1) diff --git a/process/interceptors/processor/heartbeatInterceptorProcessor_test.go b/process/interceptors/processor/heartbeatInterceptorProcessor_test.go index 3a2c3a03aff..1667e35abc6 100644 --- a/process/interceptors/processor/heartbeatInterceptorProcessor_test.go +++ b/process/interceptors/processor/heartbeatInterceptorProcessor_test.go @@ -6,19 +6,21 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/assert" + heartbeatMessages "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/heartbeat" "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" ) func createHeartbeatInterceptorProcessArg() processor.ArgHeartbeatInterceptorProcessor { return processor.ArgHeartbeatInterceptorProcessor{ - HeartbeatCacher: testscommon.NewCacherStub(), + HeartbeatCacher: cache.NewCacherStub(), ShardCoordinator: &testscommon.ShardsCoordinatorMock{}, PeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, } @@ -133,7 +135,7 @@ func TestHeartbeatInterceptorProcessor_Save(t *testing.T) { wasCalled := false providedPid := core.PeerID("pid") arg := createHeartbeatInterceptorProcessArg() - arg.HeartbeatCacher = &testscommon.CacherStub{ + arg.HeartbeatCacher = &cache.CacherStub{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { assert.True(t, bytes.Equal(providedPid.Bytes(), key)) ihb := value.(*heartbeatMessages.HeartbeatV2) diff --git a/process/interceptors/processor/interface.go b/process/interceptors/processor/interface.go index 147d8f30270..ab3baf43c5f 100644 --- a/process/interceptors/processor/interface.go +++ b/process/interceptors/processor/interface.go @@ -1,6 +1,7 @@ package processor import ( + 
"github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-go/state" ) @@ -24,3 +25,13 @@ type interceptedValidatorInfo interface { Hash() []byte ValidatorInfo() *state.ShardValidatorInfo } + +// EquivalentProofsPool defines the behaviour of a proofs pool components +type EquivalentProofsPool interface { + AddProof(headerProof data.HeaderProofHandler) bool + CleanupProofsBehindNonce(shardID uint32, nonce uint64) error + GetProof(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) + GetProofByNonce(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) + HasProof(shardID uint32, headerHash []byte) bool + IsInterfaceNil() bool +} diff --git a/process/interceptors/processor/miniblockInterceptorProcessor_test.go b/process/interceptors/processor/miniblockInterceptorProcessor_test.go index eff36ae8281..149befd1a98 100644 --- a/process/interceptors/processor/miniblockInterceptorProcessor_test.go +++ b/process/interceptors/processor/miniblockInterceptorProcessor_test.go @@ -6,13 +6,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/interceptedBlocks" "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - "github.com/stretchr/testify/assert" ) var testMarshalizer = &mock.MarshalizerMock{} @@ -20,7 +22,7 @@ var testHasher = &hashingMocks.HasherMock{} func createMockMiniblockArgument() *processor.ArgMiniblockInterceptorProcessor { return &processor.ArgMiniblockInterceptorProcessor{ - MiniblockCache: testscommon.NewCacherStub(), + MiniblockCache: cache.NewCacherStub(), Marshalizer: testMarshalizer, Hasher: testHasher, ShardCoordinator: mock.NewOneShardCoordinatorMock(), @@ -103,7 +105,7 @@ func TestNewMiniblockInterceptorProcessor_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- Validate +// ------- Validate func TestMiniblockInterceptorProcessor_ValidateShouldWork(t *testing.T) { t.Parallel() @@ -113,7 +115,7 @@ func TestMiniblockInterceptorProcessor_ValidateShouldWork(t *testing.T) { assert.Nil(t, mip.Validate(nil, "")) } -//------- Save +// ------- Save func TestMiniblockInterceptorProcessor_SaveWrongTypeAssertion(t *testing.T) { t.Parallel() @@ -129,7 +131,7 @@ func TestMiniblockInterceptorProcessor_NilMiniblockShouldNotAdd(t *testing.T) { t.Parallel() arg := createMockMiniblockArgument() - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { assert.Fail(t, "hasOrAdd should have not been called") return @@ -152,7 +154,7 @@ func TestMiniblockInterceptorProcessor_SaveMiniblockNotForCurrentShardShouldNotA } arg := createMockMiniblockArgument() - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { 
assert.Fail(t, "hasOrAdd should have not been called") return @@ -174,7 +176,7 @@ func TestMiniblockInterceptorProcessor_SaveMiniblockWithSenderInSameShardShouldA } arg := createMockMiniblockArgument() - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { _, ok := value.(*block.MiniBlock) if !ok { @@ -204,7 +206,7 @@ func TestMiniblockInterceptorProcessor_SaveMiniblocksWithReceiverInSameShardShou } arg := createMockMiniblockArgument() - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { _, ok := value.(*block.MiniBlock) if !ok { @@ -248,7 +250,7 @@ func TestMiniblockInterceptorProcessor_SaveMiniblockCrossShardForMeNotWhiteListe return false } - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { assert.Fail(t, "hasOrAdd should have not been called") return @@ -277,7 +279,7 @@ func TestMiniblockInterceptorProcessor_SaveMiniblockCrossShardForMeWhiteListedSh } addedInPool := false - cacher := arg.MiniblockCache.(*testscommon.CacherStub) + cacher := arg.MiniblockCache.(*cache.CacherStub) cacher.HasOrAddCalled = func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { addedInPool = true return false, true diff --git a/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go b/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go index 38a56751f05..3a1db0b6b66 100644 --- a/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go +++ b/process/interceptors/processor/peerAuthenticationInterceptorProcessor_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/assert" + heartbeatMessages "github.com/multiversx/mx-chain-go/heartbeat" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/heartbeat" @@ -13,9 +15,9 @@ import ( "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" ) type interceptedDataHandler interface { @@ -25,7 +27,7 @@ type interceptedDataHandler interface { func createPeerAuthenticationInterceptorProcessArg() processor.ArgPeerAuthenticationInterceptorProcessor { return processor.ArgPeerAuthenticationInterceptorProcessor{ - PeerAuthenticationCacher: testscommon.NewCacherStub(), + PeerAuthenticationCacher: cache.NewCacherStub(), PeerShardMapper: &p2pmocks.NetworkShardingCollectorStub{}, Marshaller: marshallerMock.MarshalizerMock{}, HardforkTrigger: &testscommon.HardforkTriggerStub{}, @@ -188,7 +190,7 @@ func TestPeerAuthenticationInterceptorProcessor_Save(t *testing.T) { wasPutCalled := false providedPid := core.PeerID("pid") arg := 
createPeerAuthenticationInterceptorProcessArg() - arg.PeerAuthenticationCacher = &testscommon.CacherStub{ + arg.PeerAuthenticationCacher = &cache.CacherStub{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { assert.Equal(t, providedIPAMessage.Pubkey, key) ipa := value.(*heartbeatMessages.PeerAuthentication) diff --git a/process/interceptors/processor/trieNodeChunksProcessor_test.go b/process/interceptors/processor/trieNodeChunksProcessor_test.go index f6602cddf67..ad63ca7adc6 100644 --- a/process/interceptors/processor/trieNodeChunksProcessor_test.go +++ b/process/interceptors/processor/trieNodeChunksProcessor_test.go @@ -9,8 +9,11 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/batch" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) @@ -31,7 +34,7 @@ func createMockTrieNodesChunksProcessorArgs() TrieNodesChunksProcessorArgs { return 32 }, }, - ChunksCacher: testscommon.NewCacherMock(), + ChunksCacher: cache.NewCacherMock(), RequestInterval: time.Second, RequestHandler: &testscommon.RequestHandlerStub{}, Topic: "topic", diff --git a/process/interceptors/processor/trieNodeInterceptorProcessor_test.go b/process/interceptors/processor/trieNodeInterceptorProcessor_test.go index d0bf3f66c27..b580f4ab65a 100644 --- a/process/interceptors/processor/trieNodeInterceptorProcessor_test.go +++ b/process/interceptors/processor/trieNodeInterceptorProcessor_test.go @@ -4,10 +4,12 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors/processor" "github.com/multiversx/mx-chain-go/process/mock" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) @@ -22,27 +24,27 @@ func TestNewTrieNodesInterceptorProcessor_NilCacherShouldErr(t *testing.T) { func TestNewTrieNodesInterceptorProcessor_OkValsShouldWork(t *testing.T) { t.Parallel() - tnip, err := processor.NewTrieNodesInterceptorProcessor(testscommon.NewCacherMock()) + tnip, err := processor.NewTrieNodesInterceptorProcessor(cache.NewCacherMock()) assert.Nil(t, err) assert.NotNil(t, tnip) } -//------- Validate +// ------- Validate func TestTrieNodesInterceptorProcessor_ValidateShouldWork(t *testing.T) { t.Parallel() - tnip, _ := processor.NewTrieNodesInterceptorProcessor(testscommon.NewCacherMock()) + tnip, _ := processor.NewTrieNodesInterceptorProcessor(cache.NewCacherMock()) assert.Nil(t, tnip.Validate(nil, "")) } -//------- Save +// ------- Save func TestTrieNodesInterceptorProcessor_SaveWrongTypeAssertion(t *testing.T) { t.Parallel() - tnip, _ := processor.NewTrieNodesInterceptorProcessor(testscommon.NewCacherMock()) + tnip, _ := processor.NewTrieNodesInterceptorProcessor(cache.NewCacherMock()) err := tnip.Save(nil, "", "") assert.Equal(t, process.ErrWrongTypeAssertion, err) @@ -61,7 +63,7 @@ func TestTrieNodesInterceptorProcessor_SaveShouldPutInCacher(t *testing.T) { } putCalled := false - cacher := &testscommon.CacherStub{ + cacher := &cache.CacherStub{ PutCalled: func(key []byte, value 
interface{}, sizeInBytes int) (evicted bool) { putCalled = true assert.Equal(t, len(nodeHash)+nodeSize, sizeInBytes) @@ -75,7 +77,7 @@ func TestTrieNodesInterceptorProcessor_SaveShouldPutInCacher(t *testing.T) { assert.True(t, putCalled) } -//------- IsInterfaceNil +// ------- IsInterfaceNil func TestTrieNodesInterceptorProcessor_IsInterfaceNil(t *testing.T) { t.Parallel() diff --git a/process/interceptors/singleDataInterceptor.go b/process/interceptors/singleDataInterceptor.go index 84f3296acd7..da15d00170e 100644 --- a/process/interceptors/singleDataInterceptor.go +++ b/process/interceptors/singleDataInterceptor.go @@ -1,8 +1,11 @@ package interceptors import ( + "errors" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/debug/handler" "github.com/multiversx/mx-chain-go/p2p" @@ -11,14 +14,15 @@ import ( // ArgSingleDataInterceptor is the argument for the single-data interceptor type ArgSingleDataInterceptor struct { - Topic string - DataFactory process.InterceptedDataFactory - Processor process.InterceptorProcessor - Throttler process.InterceptorThrottler - AntifloodHandler process.P2PAntifloodHandler - WhiteListRequest process.WhiteListHandler - PreferredPeersHolder process.PreferredPeersHolderHandler - CurrentPeerId core.PeerID + Topic string + DataFactory process.InterceptedDataFactory + Processor process.InterceptorProcessor + Throttler process.InterceptorThrottler + AntifloodHandler process.P2PAntifloodHandler + WhiteListRequest process.WhiteListHandler + PreferredPeersHolder process.PreferredPeersHolderHandler + CurrentPeerId core.PeerID + InterceptedDataVerifier process.InterceptedDataVerifier } // SingleDataInterceptor is used for intercepting packed multi data @@ -51,19 +55,23 @@ func NewSingleDataInterceptor(arg ArgSingleDataInterceptor) (*SingleDataIntercep if check.IfNil(arg.PreferredPeersHolder) { return nil, process.ErrNilPreferredPeersHolder } + if check.IfNil(arg.InterceptedDataVerifier) { + return nil, process.ErrNilInterceptedDataVerifier + } if len(arg.CurrentPeerId) == 0 { return nil, process.ErrEmptyPeerID } singleDataIntercept := &SingleDataInterceptor{ baseDataInterceptor: &baseDataInterceptor{ - throttler: arg.Throttler, - antifloodHandler: arg.AntifloodHandler, - topic: arg.Topic, - currentPeerId: arg.CurrentPeerId, - processor: arg.Processor, - preferredPeersHolder: arg.PreferredPeersHolder, - debugHandler: handler.NewDisabledInterceptorDebugHandler(), + throttler: arg.Throttler, + antifloodHandler: arg.AntifloodHandler, + topic: arg.Topic, + currentPeerId: arg.CurrentPeerId, + processor: arg.Processor, + preferredPeersHolder: arg.PreferredPeersHolder, + debugHandler: handler.NewDisabledInterceptorDebugHandler(), + interceptedDataVerifier: arg.InterceptedDataVerifier, }, factory: arg.DataFactory, whiteListRequest: arg.WhiteListRequest, @@ -74,13 +82,13 @@ func NewSingleDataInterceptor(arg ArgSingleDataInterceptor) (*SingleDataIntercep // ProcessReceivedMessage is the callback func from the p2p.Messenger and will be called each time a new message was received // (for the topic this validator was registered to) -func (sdi *SingleDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, _ p2p.MessageHandler) error { +func (sdi *SingleDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer 
core.PeerID, _ p2p.MessageHandler) ([]byte, error) { err := sdi.preProcessMesage(message, fromConnectedPeer) if err != nil { - return err + return nil, err } - interceptedData, err := sdi.factory.Create(message.Data()) + interceptedData, err := sdi.factory.Create(message.Data(), message.Peer()) if err != nil { sdi.throttler.EndProcessing() @@ -89,17 +97,16 @@ func (sdi *SingleDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, sdi.antifloodHandler.BlacklistPeer(message.Peer(), reason, common.InvalidMessageBlacklistDuration) sdi.antifloodHandler.BlacklistPeer(fromConnectedPeer, reason, common.InvalidMessageBlacklistDuration) - return err + return nil, err } sdi.receivedDebugInterceptedData(interceptedData) - - err = interceptedData.CheckValidity() + err = sdi.interceptedDataVerifier.Verify(interceptedData) if err != nil { sdi.throttler.EndProcessing() sdi.processDebugInterceptedData(interceptedData, err) - isWrongVersion := err == process.ErrInvalidTransactionVersion || err == process.ErrInvalidChainID + isWrongVersion := errors.Is(err, process.ErrInvalidTransactionVersion) || errors.Is(err, process.ErrInvalidChainID) if isWrongVersion { // this situation is so severe that we need to black list de peers reason := "wrong version of received intercepted data, topic " + sdi.topic + ", error " + err.Error() @@ -107,7 +114,7 @@ func (sdi *SingleDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, sdi.antifloodHandler.BlacklistPeer(fromConnectedPeer, reason, common.InvalidMessageBlacklistDuration) } - return err + return nil, err } errOriginator := sdi.antifloodHandler.IsOriginatorEligibleForTopic(message.Peer(), sdi.topic) @@ -117,9 +124,10 @@ func (sdi *SingleDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, "originator", p2p.PeerIdToShortString(message.Peer()), "topic", sdi.topic, "err", errOriginator) sdi.throttler.EndProcessing() - return errOriginator + return nil, errOriginator } + messageID := interceptedData.Hash() isForCurrentShard := interceptedData.IsForCurrentShard() shouldProcess := isForCurrentShard || isWhiteListed if !shouldProcess { @@ -133,7 +141,7 @@ func (sdi *SingleDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, "is white listed", isWhiteListed, ) - return nil + return messageID, nil } go func() { @@ -141,7 +149,7 @@ func (sdi *SingleDataInterceptor) ProcessReceivedMessage(message p2p.MessageP2P, sdi.throttler.EndProcessing() }() - return nil + return messageID, nil } // RegisterHandler registers a callback function to be notified on received data diff --git a/process/interceptors/singleDataInterceptor_test.go b/process/interceptors/singleDataInterceptor_test.go index 515c2a8724c..84aa285ff6c 100644 --- a/process/interceptors/singleDataInterceptor_test.go +++ b/process/interceptors/singleDataInterceptor_test.go @@ -8,25 +8,27 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" - "github.com/stretchr/testify/assert" - 
"github.com/stretchr/testify/require" ) func createMockArgSingleDataInterceptor() interceptors.ArgSingleDataInterceptor { return interceptors.ArgSingleDataInterceptor{ - Topic: "test topic", - DataFactory: &mock.InterceptedDataFactoryStub{}, - Processor: &mock.InterceptorProcessorStub{}, - Throttler: createMockThrottler(), - AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, - WhiteListRequest: &testscommon.WhiteListHandlerStub{}, - PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, - CurrentPeerId: "pid", + Topic: "test topic", + DataFactory: &mock.InterceptedDataFactoryStub{}, + Processor: &mock.InterceptorProcessorStub{}, + Throttler: createMockThrottler(), + AntifloodHandler: &mock.P2PAntifloodHandlerStub{}, + WhiteListRequest: &testscommon.WhiteListHandlerStub{}, + PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, + CurrentPeerId: "pid", + InterceptedDataVerifier: createMockInterceptedDataVerifier(), } } @@ -57,6 +59,14 @@ func createMockThrottler() *mock.InterceptorThrottlerStub { } } +func createMockInterceptedDataVerifier() *mock.InterceptedDataVerifierMock { + return &mock.InterceptedDataVerifierMock{ + VerifyCalled: func(interceptedData process.InterceptedData) error { + return interceptedData.CheckValidity() + }, + } +} + func TestNewSingleDataInterceptor_EmptyTopicShouldErr(t *testing.T) { t.Parallel() @@ -145,6 +155,17 @@ func TestNewSingleDataInterceptor_EmptyPeerIDShouldErr(t *testing.T) { assert.Equal(t, process.ErrEmptyPeerID, err) } +func TestNewSingleDataInterceptor_NilInterceptedDataVerifierShouldErr(t *testing.T) { + t.Parallel() + + arg := createMockArgMultiDataInterceptor() + arg.InterceptedDataVerifier = nil + mdi, err := interceptors.NewMultiDataInterceptor(arg) + + assert.True(t, check.IfNil(mdi)) + assert.Equal(t, process.ErrNilInterceptedDataVerifier, err) +} + func TestNewSingleDataInterceptor(t *testing.T) { t.Parallel() @@ -156,7 +177,7 @@ func TestNewSingleDataInterceptor(t *testing.T) { assert.Equal(t, arg.Topic, sdi.Topic()) } -//------- ProcessReceivedMessage +// ------- ProcessReceivedMessage func TestSingleDataInterceptor_ProcessReceivedMessageNilMessageShouldErr(t *testing.T) { t.Parallel() @@ -164,9 +185,10 @@ func TestSingleDataInterceptor_ProcessReceivedMessageNilMessageShouldErr(t *test arg := createMockArgSingleDataInterceptor() sdi, _ := interceptors.NewSingleDataInterceptor(arg) - err := sdi.ProcessReceivedMessage(nil, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := sdi.ProcessReceivedMessage(nil, fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, process.ErrNilMessage, err) + assert.Nil(t, msgID) } func TestSingleDataInterceptor_ProcessReceivedMessageFactoryCreationErrorShouldErr(t *testing.T) { @@ -198,11 +220,12 @@ func TestSingleDataInterceptor_ProcessReceivedMessageFactoryCreationErrorShouldE DataField: []byte("data to be processed"), PeerField: originatorPid, } - err := sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, errExpected, err) assert.True(t, originatorBlackListed) assert.True(t, fromConnectedPeerBlackListed) + assert.Nil(t, msgID) } func TestSingleDataInterceptor_ProcessReceivedMessageIsNotValidShouldNotCallProcess(t *testing.T) { @@ -250,7 +273,7 @@ func testProcessReceiveMessage(t *testing.T, isForCurrentShard bool, validityErr msg := &p2pmocks.P2PMessageMock{ DataField: []byte("data to be processed"), } - err := 
sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) @@ -259,6 +282,7 @@ func testProcessReceiveMessage(t *testing.T, isForCurrentShard bool, validityErr assert.Equal(t, int32(calledNum), atomic.LoadInt32(&processCalledNum)) assert.Equal(t, int32(1), throttler.EndProcessingCount()) assert.Equal(t, int32(1), throttler.EndProcessingCount()) + assert.Nil(t, msgID) } func TestSingleDataInterceptor_ProcessReceivedMessageWhitelistedShouldWork(t *testing.T) { @@ -267,6 +291,7 @@ func TestSingleDataInterceptor_ProcessReceivedMessageWhitelistedShouldWork(t *te checkCalledNum := int32(0) processCalledNum := int32(0) throttler := createMockThrottler() + msgHash := []byte("hash") interceptedData := &testscommon.InterceptedDataStub{ CheckValidityCalled: func() error { return nil @@ -274,6 +299,9 @@ func TestSingleDataInterceptor_ProcessReceivedMessageWhitelistedShouldWork(t *te IsForCurrentShardCalled: func() bool { return false }, + HashCalled: func() []byte { + return msgHash + }, } arg := createMockArgSingleDataInterceptor() @@ -294,7 +322,7 @@ func TestSingleDataInterceptor_ProcessReceivedMessageWhitelistedShouldWork(t *te msg := &p2pmocks.P2PMessageMock{ DataField: []byte("data to be processed"), } - err := sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) @@ -303,6 +331,7 @@ func TestSingleDataInterceptor_ProcessReceivedMessageWhitelistedShouldWork(t *te assert.Equal(t, int32(1), atomic.LoadInt32(&processCalledNum)) assert.Equal(t, int32(1), throttler.EndProcessingCount()) assert.Equal(t, int32(1), throttler.EndProcessingCount()) + assert.Equal(t, msgHash, msgID) } func TestSingleDataInterceptor_InvalidTxVersionShouldBlackList(t *testing.T) { @@ -362,10 +391,12 @@ func processReceivedMessageSingleDataInvalidVersion(t *testing.T, expectedErr er DataField: []byte("data to be processed"), PeerField: originator, } - err := sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + + msgID, err := sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) assert.Equal(t, expectedErr, err) assert.True(t, isFromConnectedPeerBlackListed) assert.True(t, isOriginatorBlackListed) + assert.Nil(t, msgID) } func TestSingleDataInterceptor_ProcessReceivedMessageWithOriginator(t *testing.T) { @@ -374,6 +405,7 @@ func TestSingleDataInterceptor_ProcessReceivedMessageWithOriginator(t *testing.T checkCalledNum := int32(0) processCalledNum := int32(0) throttler := createMockThrottler() + msgHash := []byte("hash") interceptedData := &testscommon.InterceptedDataStub{ CheckValidityCalled: func() error { return nil @@ -381,6 +413,9 @@ func TestSingleDataInterceptor_ProcessReceivedMessageWithOriginator(t *testing.T IsForCurrentShardCalled: func() bool { return false }, + HashCalled: func() []byte { + return msgHash + }, } whiteListHandler := &testscommon.WhiteListHandlerStub{ @@ -407,7 +442,7 @@ func TestSingleDataInterceptor_ProcessReceivedMessageWithOriginator(t *testing.T msg := &p2pmocks.P2PMessageMock{ DataField: []byte("data to be processed"), } - err := sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err := sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) @@ -416,12 +451,13 @@ func 
TestSingleDataInterceptor_ProcessReceivedMessageWithOriginator(t *testing.T assert.Equal(t, int32(1), atomic.LoadInt32(&processCalledNum)) assert.Equal(t, int32(1), throttler.EndProcessingCount()) assert.Equal(t, int32(1), throttler.EndProcessingCount()) + assert.Equal(t, msgHash, msgID) whiteListHandler.IsWhiteListedCalled = func(interceptedData process.InterceptedData) bool { return false } - err = sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) + msgID, err = sdi.ProcessReceivedMessage(msg, fromConnectedPeerId, &p2pmocks.MessengerStub{}) time.Sleep(time.Second) @@ -430,9 +466,10 @@ func TestSingleDataInterceptor_ProcessReceivedMessageWithOriginator(t *testing.T assert.Equal(t, int32(1), atomic.LoadInt32(&processCalledNum)) assert.Equal(t, int32(2), throttler.EndProcessingCount()) assert.Equal(t, int32(2), throttler.EndProcessingCount()) + assert.Nil(t, msgID) } -//------- debug +// ------- debug func TestSingleDataInterceptor_SetInterceptedDebugHandlerNilShouldErr(t *testing.T) { t.Parallel() @@ -455,7 +492,7 @@ func TestSingleDataInterceptor_SetInterceptedDebugHandlerShouldWork(t *testing.T err := sdi.SetInterceptedDebugHandler(debugger) assert.Nil(t, err) - assert.True(t, debugger == sdi.InterceptedDebugHandler()) //pointer testing + assert.True(t, debugger == sdi.InterceptedDebugHandler()) // pointer testing } func TestSingleDataInterceptor_Close(t *testing.T) { @@ -468,7 +505,7 @@ func TestSingleDataInterceptor_Close(t *testing.T) { assert.Nil(t, err) } -//------- IsInterfaceNil +// ------- IsInterfaceNil func TestSingleDataInterceptor_IsInterfaceNil(t *testing.T) { t.Parallel() diff --git a/process/interceptors/whiteListDataVerifier_test.go b/process/interceptors/whiteListDataVerifier_test.go index c1567465fcc..f974f2f2c02 100644 --- a/process/interceptors/whiteListDataVerifier_test.go +++ b/process/interceptors/whiteListDataVerifier_test.go @@ -6,8 +6,11 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) @@ -23,7 +26,7 @@ func TestNewWhiteListDataVerifier_NilCacherShouldErr(t *testing.T) { func TestNewWhiteListDataVerifier_ShouldWork(t *testing.T) { t.Parallel() - wldv, err := NewWhiteListDataVerifier(testscommon.NewCacherStub()) + wldv, err := NewWhiteListDataVerifier(cache.NewCacherStub()) assert.False(t, check.IfNil(wldv)) assert.Nil(t, err) @@ -34,7 +37,7 @@ func TestWhiteListDataVerifier_Add(t *testing.T) { keys := [][]byte{[]byte("key1"), []byte("key2")} added := map[string]struct{}{} - cacher := &testscommon.CacherStub{ + cacher := &cache.CacherStub{ PutCalled: func(key []byte, value interface{}, sizeInBytes int) (evicted bool) { added[string(key)] = struct{}{} return false @@ -55,7 +58,7 @@ func TestWhiteListDataVerifier_Remove(t *testing.T) { keys := [][]byte{[]byte("key1"), []byte("key2")} removed := map[string]struct{}{} - cacher := &testscommon.CacherStub{ + cacher := &cache.CacherStub{ RemoveCalled: func(key []byte) { removed[string(key)] = struct{}{} }, @@ -73,7 +76,7 @@ func TestWhiteListDataVerifier_Remove(t *testing.T) { func TestWhiteListDataVerifier_IsWhiteListedNilInterceptedDataShouldRetFalse(t *testing.T) { t.Parallel() - wldv, _ := NewWhiteListDataVerifier(testscommon.NewCacherStub()) + wldv, _ := 
NewWhiteListDataVerifier(cache.NewCacherStub()) assert.False(t, wldv.IsWhiteListed(nil)) } @@ -83,7 +86,7 @@ func TestWhiteListDataVerifier_IsWhiteListedNotFoundShouldRetFalse(t *testing.T) keyCheck := []byte("key") wldv, _ := NewWhiteListDataVerifier( - &testscommon.CacherStub{ + &cache.CacherStub{ HasCalled: func(key []byte) bool { return !bytes.Equal(key, keyCheck) }, @@ -104,7 +107,7 @@ func TestWhiteListDataVerifier_IsWhiteListedFoundShouldRetTrue(t *testing.T) { keyCheck := []byte("key") wldv, _ := NewWhiteListDataVerifier( - &testscommon.CacherStub{ + &cache.CacherStub{ HasCalled: func(key []byte) bool { return bytes.Equal(key, keyCheck) }, diff --git a/process/interface.go b/process/interface.go index 13e825f4713..bf8d4040b26 100644 --- a/process/interface.go +++ b/process/interface.go @@ -25,6 +25,7 @@ import ( "github.com/multiversx/mx-chain-go/common" cryptoCommon "github.com/multiversx/mx-chain-go/common/crypto" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" @@ -116,7 +117,7 @@ type HdrValidatorHandler interface { // InterceptedDataFactory can create new instances of InterceptedData type InterceptedDataFactory interface { - Create(buff []byte) (InterceptedData, error) + Create(buff []byte, messageOriginator core.PeerID) (InterceptedData, error) IsInterfaceNil() bool } @@ -383,7 +384,9 @@ type ForkDetector interface { RestoreToGenesis() GetNotarizedHeaderHash(nonce uint64) []byte ResetProbableHighestNonce() + AddCheckpoint(nonce uint64, round uint64, hash []byte) SetFinalToLastCheckpoint() + ReceivedProof(proof data.HeaderProofHandler) IsInterfaceNil() bool } @@ -560,7 +563,7 @@ type BlockChainHookWithAccountsAdapter interface { // Interceptor defines what a data interceptor should do // It should also adhere to the p2p.MessageProcessor interface so it can wire to a p2p.Messenger type Interceptor interface { - ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error + ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) SetInterceptedDebugHandler(handler InterceptedDebugger) error RegisterHandler(handler func(topic string, hash []byte, data interface{})) Close() error @@ -610,6 +613,8 @@ type RequestHandler interface { RequestPeerAuthenticationsByHashes(destShardID uint32, hashes [][]byte) RequestValidatorInfo(hash []byte) RequestValidatorsInfo(hashes [][]byte) + RequestEquivalentProofByHash(headerShard uint32, headerHash []byte) + RequestEquivalentProofByNonce(headerShard uint32, headerNonce uint64) IsInterfaceNil() bool } @@ -864,6 +869,8 @@ type InterceptedHeaderSigVerifier interface { VerifyRandSeed(header data.HeaderHandler) error VerifyLeaderSignature(header data.HeaderHandler) error VerifySignature(header data.HeaderHandler) error + VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error + VerifyHeaderProof(headerProof data.HeaderProofHandler) error IsInterfaceNil() bool } @@ -927,12 +934,18 @@ type TopicFloodPreventer interface { IsInterfaceNil() bool } +// ChainParametersSubscriber is the interface that can be used to subscribe for chain parameters changes +type ChainParametersSubscriber interface { + RegisterNotifyHandler(handler 
common.ChainParametersSubscriptionHandler) + IsInterfaceNil() bool +} + // P2PAntifloodHandler defines the behavior of a component able to signal that the system is too busy (or flooded) processing // p2p messages type P2PAntifloodHandler interface { CanProcessMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error CanProcessMessagesOnTopic(pid core.PeerID, topic string, numMessages uint32, totalSize uint64, sequence []byte) error - ApplyConsensusSize(size int) + SetConsensusSizeNotifier(chainParametersNotifier ChainParametersSubscriber, shardID uint32) SetDebugger(debugger AntifloodDebugger) error BlacklistPeer(peer core.PeerID, reason string, duration time.Duration) IsOriginatorEligibleForTopic(pid core.PeerID, topic string) error @@ -1049,6 +1062,7 @@ type RatingsInfoHandler interface { MetaChainRatingsStepHandler() RatingsStepHandler ShardChainRatingsStepHandler() RatingsStepHandler SelectionChances() []SelectionChance + SetStatusHandler(handler core.AppStatusHandler) error IsInterfaceNil() bool } @@ -1061,6 +1075,36 @@ type RatingsStepHandler interface { ConsecutiveMissedBlocksPenalty() float32 } +// NodesSetupHandler returns the nodes' configuration +type NodesSetupHandler interface { + AllInitialNodes() []nodesCoordinator.GenesisNodeInfoHandler + InitialNodesPubKeys() map[uint32][]string + GetShardIDForPubKey(pubkey []byte) (uint32, error) + InitialEligibleNodesPubKeysForShard(shardId uint32) ([]string, error) + InitialNodesInfoForShard(shardId uint32) ([]nodesCoordinator.GenesisNodeInfoHandler, []nodesCoordinator.GenesisNodeInfoHandler, error) + InitialNodesInfo() (map[uint32][]nodesCoordinator.GenesisNodeInfoHandler, map[uint32][]nodesCoordinator.GenesisNodeInfoHandler) + GetStartTime() int64 + GetRoundDuration() uint64 + GetShardConsensusGroupSize() uint32 + GetMetaConsensusGroupSize() uint32 + NumberOfShards() uint32 + MinNumberOfNodes() uint32 + MinNumberOfShardNodes() uint32 + MinNumberOfMetaNodes() uint32 + GetHysteresis() float32 + GetAdaptivity() bool + MinNumberOfNodesWithHysteresis() uint32 + IsInterfaceNil() bool +} + +// ChainParametersHandler defines the actions that need to be done by a component that can handle chain parameters +type ChainParametersHandler interface { + CurrentChainParameters() config.ChainParametersByEpochConfig + AllChainParameters() []config.ChainParametersByEpochConfig + ChainParametersForEpoch(epoch uint32) (config.ChainParametersByEpochConfig, error) + IsInterfaceNil() bool +} + // ValidatorInfoSyncer defines the method needed for validatorInfoProcessing type ValidatorInfoSyncer interface { SyncMiniBlocks(headerHandler data.HeaderHandler) ([][]byte, data.BodyHandler, error) @@ -1132,6 +1176,7 @@ type EpochStartEventNotifier interface { type NodesCoordinator interface { GetValidatorWithPublicKey(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) GetAllEligibleValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) + GetAllEligibleValidatorsPublicKeysForShard(epoch uint32, shardID uint32) ([]string, error) GetAllWaitingValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) GetAllLeavingValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) IsInterfaceNil() bool @@ -1180,6 +1225,7 @@ type PayableHandler interface { // FallbackHeaderValidator defines the behaviour of a component able to signal when a fallback header validation could be applied type FallbackHeaderValidator interface { + ShouldApplyFallbackValidationForHeaderWith(shardID uint32, startOfEpochBlock bool, round 
uint64, prevHeaderHash []byte) bool ShouldApplyFallbackValidation(headerHandler data.HeaderHandler) bool IsInterfaceNil() bool } @@ -1200,11 +1246,15 @@ type CoreComponentsHolder interface { TxVersionChecker() TxVersionCheckerHandler GenesisNodesSetup() sharding.GenesisNodesSetupHandler EpochNotifier() EpochNotifier + ChainParametersSubscriber() ChainParametersSubscriber ChanStopNodeProcess() chan endProcess.ArgEndProcess NodeTypeProvider() core.NodeTypeProviderHandler ProcessStatusHandler() common.ProcessStatusHandler HardforkTriggerPubKey() []byte EnableEpochsHandler() common.EnableEpochsHandler + ChainParametersHandler() ChainParametersHandler + FieldsSizeChecker() common.FieldsSizeChecker + EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler IsInterfaceNil() bool } @@ -1375,3 +1425,24 @@ type SentSignaturesTracker interface { ResetCountersForManagedBlockSigner(signerPk []byte) IsInterfaceNil() bool } + +// InterceptedDataVerifier defines a component able to verify intercepted data validity +type InterceptedDataVerifier interface { + Verify(interceptedData InterceptedData) error + IsInterfaceNil() bool +} + +// InterceptedDataVerifierFactory defines a component that is able to create intercepted data verifiers +type InterceptedDataVerifierFactory interface { + Create(topic string) (InterceptedDataVerifier, error) + Close() error + IsInterfaceNil() bool +} + +// ProofsPool defines the behaviour of a proofs pool components +type ProofsPool interface { + AddProof(headerProof data.HeaderProofHandler) bool + HasProof(shardID uint32, headerHash []byte) bool + IsProofInPoolEqualTo(headerProof data.HeaderProofHandler) bool + IsInterfaceNil() bool +} diff --git a/process/mock/coreComponentsMock.go b/process/mock/coreComponentsMock.go index 9082cbf7b37..6f6d147d8f3 100644 --- a/process/mock/coreComponentsMock.go +++ b/process/mock/coreComponentsMock.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/process" @@ -17,29 +18,33 @@ import ( // CoreComponentsMock - type CoreComponentsMock struct { - IntMarsh marshal.Marshalizer - TxMarsh marshal.Marshalizer - Hash hashing.Hasher - TxSignHasherField hashing.Hasher - UInt64ByteSliceConv typeConverters.Uint64ByteSliceConverter - AddrPubKeyConv core.PubkeyConverter - ValPubKeyConv core.PubkeyConverter - PathHdl storage.PathManagerHandler - ChainIdCalled func() string - MinTransactionVersionCalled func() uint32 - GenesisNodesSetupCalled func() sharding.GenesisNodesSetupHandler - TxVersionCheckField process.TxVersionCheckerHandler - EpochNotifierField process.EpochNotifier - EnableEpochsHandlerField common.EnableEpochsHandler - RoundNotifierField process.RoundNotifier - EnableRoundsHandlerField process.EnableRoundsHandler - RoundField consensus.RoundHandler - StatusField core.AppStatusHandler - ChanStopNode chan endProcess.ArgEndProcess - NodeTypeProviderField core.NodeTypeProviderHandler - EconomicsDataField process.EconomicsDataHandler - ProcessStatusHandlerField common.ProcessStatusHandler - HardforkTriggerPubKeyField []byte + IntMarsh marshal.Marshalizer + TxMarsh marshal.Marshalizer + Hash hashing.Hasher + TxSignHasherField hashing.Hasher + UInt64ByteSliceConv 
typeConverters.Uint64ByteSliceConverter + AddrPubKeyConv core.PubkeyConverter + ValPubKeyConv core.PubkeyConverter + PathHdl storage.PathManagerHandler + ChainIdCalled func() string + MinTransactionVersionCalled func() uint32 + GenesisNodesSetupCalled func() sharding.GenesisNodesSetupHandler + TxVersionCheckField process.TxVersionCheckerHandler + EpochNotifierField process.EpochNotifier + EnableEpochsHandlerField common.EnableEpochsHandler + RoundNotifierField process.RoundNotifier + EnableRoundsHandlerField process.EnableRoundsHandler + RoundField consensus.RoundHandler + StatusField core.AppStatusHandler + ChanStopNode chan endProcess.ArgEndProcess + NodeTypeProviderField core.NodeTypeProviderHandler + EconomicsDataField process.EconomicsDataHandler + ProcessStatusHandlerField common.ProcessStatusHandler + ChainParametersHandlerField process.ChainParametersHandler + HardforkTriggerPubKeyField []byte + ChainParametersSubscriberField process.ChainParametersSubscriber + FieldsSizeCheckerField common.FieldsSizeChecker + EpochChangeGracePeriodHandlerField common.EpochChangeGracePeriodHandler } // ChanStopNodeProcess - @@ -82,6 +87,11 @@ func (ccm *CoreComponentsMock) Uint64ByteSliceConverter() typeConverters.Uint64B return ccm.UInt64ByteSliceConv } +// ChainParametersHandler - +func (ccm *CoreComponentsMock) ChainParametersHandler() process.ChainParametersHandler { + return ccm.ChainParametersHandlerField +} + // AddressPubKeyConverter - func (ccm *CoreComponentsMock) AddressPubKeyConverter() core.PubkeyConverter { return ccm.AddrPubKeyConv @@ -175,6 +185,21 @@ func (ccm *CoreComponentsMock) HardforkTriggerPubKey() []byte { return ccm.HardforkTriggerPubKeyField } +// ChainParametersSubscriber - +func (ccm *CoreComponentsMock) ChainParametersSubscriber() process.ChainParametersSubscriber { + return ccm.ChainParametersSubscriberField +} + +// FieldsSizeChecker - +func (ccm *CoreComponentsMock) FieldsSizeChecker() common.FieldsSizeChecker { + return ccm.FieldsSizeCheckerField +} + +// EpochChangeGracePeriodHandler - +func (ccm *CoreComponentsMock) EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler { + return ccm.EpochChangeGracePeriodHandlerField +} + // IsInterfaceNil - func (ccm *CoreComponentsMock) IsInterfaceNil() bool { return ccm == nil diff --git a/process/mock/floodPreventerStub.go b/process/mock/floodPreventerStub.go index d9d1a8881c3..85367b10545 100644 --- a/process/mock/floodPreventerStub.go +++ b/process/mock/floodPreventerStub.go @@ -11,7 +11,11 @@ type FloodPreventerStub struct { // IncreaseLoad - func (fps *FloodPreventerStub) IncreaseLoad(pid core.PeerID, size uint64) error { - return fps.IncreaseLoadCalled(pid, size) + if fps.IncreaseLoadCalled != nil { + return fps.IncreaseLoadCalled(pid, size) + } + + return nil } // ApplyConsensusSize - diff --git a/process/mock/forkDetectorMock.go b/process/mock/forkDetectorMock.go index a574e4724b1..65f75c83763 100644 --- a/process/mock/forkDetectorMock.go +++ b/process/mock/forkDetectorMock.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -19,6 +20,8 @@ type ForkDetectorMock struct { RestoreToGenesisCalled func() ResetProbableHighestNonceCalled func() SetFinalToLastCheckpointCalled func() + ReceivedProofCalled func(proof data.HeaderProofHandler) + AddCheckpointCalled func(nonce uint64, round uint64, hash []byte) } // RestoreToGenesis - @@ -28,17 +31,27 @@ func (fdm *ForkDetectorMock) 
RestoreToGenesis() { // AddHeader - func (fdm *ForkDetectorMock) AddHeader(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error { - return fdm.AddHeaderCalled(header, hash, state, selfNotarizedHeaders, selfNotarizedHeadersHashes) + if fdm.AddHeaderCalled != nil { + return fdm.AddHeaderCalled(header, hash, state, selfNotarizedHeaders, selfNotarizedHeadersHashes) + } + + return nil } // RemoveHeader - func (fdm *ForkDetectorMock) RemoveHeader(nonce uint64, hash []byte) { - fdm.RemoveHeaderCalled(nonce, hash) + if fdm.RemoveHeaderCalled != nil { + fdm.RemoveHeaderCalled(nonce, hash) + } } // CheckFork - func (fdm *ForkDetectorMock) CheckFork() *process.ForkInfo { - return fdm.CheckForkCalled() + if fdm.CheckForkCalled != nil { + return fdm.CheckForkCalled() + } + + return nil } // GetHighestFinalBlockNonce - @@ -51,12 +64,20 @@ func (fdm *ForkDetectorMock) GetHighestFinalBlockNonce() uint64 { // GetHighestFinalBlockHash - func (fdm *ForkDetectorMock) GetHighestFinalBlockHash() []byte { - return fdm.GetHighestFinalBlockHashCalled() + if fdm.GetHighestFinalBlockHashCalled != nil { + return fdm.GetHighestFinalBlockHashCalled() + } + + return nil } // ProbableHighestNonce - func (fdm *ForkDetectorMock) ProbableHighestNonce() uint64 { - return fdm.ProbableHighestNonceCalled() + if fdm.ProbableHighestNonceCalled != nil { + return fdm.ProbableHighestNonceCalled() + } + + return 0 } // SetRollBackNonce - @@ -68,12 +89,18 @@ func (fdm *ForkDetectorMock) SetRollBackNonce(nonce uint64) { // ResetFork - func (fdm *ForkDetectorMock) ResetFork() { - fdm.ResetForkCalled() + if fdm.ResetForkCalled != nil { + fdm.ResetForkCalled() + } } // GetNotarizedHeaderHash - func (fdm *ForkDetectorMock) GetNotarizedHeaderHash(nonce uint64) []byte { - return fdm.GetNotarizedHeaderHashCalled(nonce) + if fdm.GetNotarizedHeaderHashCalled != nil { + return fdm.GetNotarizedHeaderHashCalled(nonce) + } + + return nil } // ResetProbableHighestNonce - @@ -90,6 +117,20 @@ func (fdm *ForkDetectorMock) SetFinalToLastCheckpoint() { } } +// ReceivedProof - +func (fdm *ForkDetectorMock) ReceivedProof(proof data.HeaderProofHandler) { + if fdm.ReceivedProofCalled != nil { + fdm.ReceivedProofCalled(proof) + } +} + +// AddCheckpoint - +func (fdm *ForkDetectorMock) AddCheckpoint(nonce uint64, round uint64, hash []byte) { + if fdm.AddCheckpointCalled != nil { + fdm.AddCheckpointCalled(nonce, round, hash) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (fdm *ForkDetectorMock) IsInterfaceNil() bool { return fdm == nil diff --git a/process/mock/headerSigVerifierStub.go b/process/mock/headerSigVerifierStub.go deleted file mode 100644 index efc83c06e18..00000000000 --- a/process/mock/headerSigVerifierStub.go +++ /dev/null @@ -1,52 +0,0 @@ -package mock - -import "github.com/multiversx/mx-chain-core-go/data" - -// HeaderSigVerifierStub - -type HeaderSigVerifierStub struct { - VerifyLeaderSignatureCalled func(header data.HeaderHandler) error - VerifyRandSeedCalled func(header data.HeaderHandler) error - VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error - VerifySignatureCalled func(header data.HeaderHandler) error -} - -// VerifyRandSeed - -func (hsvm *HeaderSigVerifierStub) VerifyRandSeed(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedCalled != nil { - return hsvm.VerifyRandSeedCalled(header) - } - - return nil -} - -// VerifyRandSeedAndLeaderSignature 
- -func (hsvm *HeaderSigVerifierStub) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyRandSeedAndLeaderSignatureCalled != nil { - return hsvm.VerifyRandSeedAndLeaderSignatureCalled(header) - } - - return nil -} - -// VerifyLeaderSignature - -func (hsvm *HeaderSigVerifierStub) VerifyLeaderSignature(header data.HeaderHandler) error { - if hsvm.VerifyLeaderSignatureCalled != nil { - return hsvm.VerifyLeaderSignatureCalled(header) - } - - return nil -} - -// VerifySignature - -func (hsvm *HeaderSigVerifierStub) VerifySignature(header data.HeaderHandler) error { - if hsvm.VerifySignatureCalled != nil { - return hsvm.VerifySignatureCalled(header) - } - - return nil -} - -// IsInterfaceNil - -func (hsvm *HeaderSigVerifierStub) IsInterfaceNil() bool { - return hsvm == nil -} diff --git a/process/mock/interceptedDataFactoryStub.go b/process/mock/interceptedDataFactoryStub.go index 4a0445659e7..3481f42e5b7 100644 --- a/process/mock/interceptedDataFactoryStub.go +++ b/process/mock/interceptedDataFactoryStub.go @@ -1,6 +1,9 @@ package mock -import "github.com/multiversx/mx-chain-go/process" +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/process" +) // InterceptedDataFactoryStub - type InterceptedDataFactoryStub struct { @@ -8,7 +11,7 @@ type InterceptedDataFactoryStub struct { } // Create - -func (idfs *InterceptedDataFactoryStub) Create(buff []byte) (process.InterceptedData, error) { +func (idfs *InterceptedDataFactoryStub) Create(buff []byte, _ core.PeerID) (process.InterceptedData, error) { return idfs.CreateCalled(buff) } diff --git a/process/mock/interceptedDataVerifierFactoryMock.go b/process/mock/interceptedDataVerifierFactoryMock.go new file mode 100644 index 00000000000..245be014b15 --- /dev/null +++ b/process/mock/interceptedDataVerifierFactoryMock.go @@ -0,0 +1,29 @@ +package mock + +import ( + "github.com/multiversx/mx-chain-go/process" +) + +// InterceptedDataVerifierFactoryMock - +type InterceptedDataVerifierFactoryMock struct { + CreateCalled func(topic string) (process.InterceptedDataVerifier, error) +} + +// Create - +func (idvfs *InterceptedDataVerifierFactoryMock) Create(topic string) (process.InterceptedDataVerifier, error) { + if idvfs.CreateCalled != nil { + return idvfs.CreateCalled(topic) + } + + return &InterceptedDataVerifierMock{}, nil +} + +// Close - +func (idvfs *InterceptedDataVerifierFactoryMock) Close() error { + return nil +} + +// IsInterfaceNil - +func (idvfs *InterceptedDataVerifierFactoryMock) IsInterfaceNil() bool { + return idvfs == nil +} diff --git a/process/mock/interceptedDataVerifierMock.go b/process/mock/interceptedDataVerifierMock.go index c8d4d14392b..6668a6ea625 100644 --- a/process/mock/interceptedDataVerifierMock.go +++ b/process/mock/interceptedDataVerifierMock.go @@ -1,17 +1,24 @@ package mock -import "github.com/multiversx/mx-chain-go/process" +import ( + "github.com/multiversx/mx-chain-go/process" +) // InterceptedDataVerifierMock - type InterceptedDataVerifierMock struct { + VerifyCalled func(interceptedData process.InterceptedData) error } -// IsForCurrentShard - -func (i *InterceptedDataVerifierMock) IsForCurrentShard(_ process.InterceptedData) bool { - return true +// Verify - +func (idv *InterceptedDataVerifierMock) Verify(interceptedData process.InterceptedData) error { + if idv.VerifyCalled != nil { + return idv.VerifyCalled(interceptedData) + } + + return 
nil } -// IsInterfaceNil returns true if underlying object is -func (i *InterceptedDataVerifierMock) IsInterfaceNil() bool { - return i == nil +// IsInterfaceNil - +func (idv *InterceptedDataVerifierMock) IsInterfaceNil() bool { + return idv == nil } diff --git a/process/mock/p2pAntifloodHandlerStub.go b/process/mock/p2pAntifloodHandlerStub.go index 7dd6e011474..819267b91ae 100644 --- a/process/mock/p2pAntifloodHandlerStub.go +++ b/process/mock/p2pAntifloodHandlerStub.go @@ -12,7 +12,7 @@ import ( type P2PAntifloodHandlerStub struct { CanProcessMessageCalled func(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error CanProcessMessagesOnTopicCalled func(peer core.PeerID, topic string, numMessages uint32, totalSize uint64, sequence []byte) error - ApplyConsensusSizeCalled func(size int) + SetConsensusSizeNotifierCalled func(subscriber process.ChainParametersSubscriber, shardID uint32) SetDebuggerCalled func(debugger process.AntifloodDebugger) error BlacklistPeerCalled func(peer core.PeerID, reason string, duration time.Duration) IsOriginatorEligibleForTopicCalled func(pid core.PeerID, topic string) error @@ -42,10 +42,10 @@ func (p2pahs *P2PAntifloodHandlerStub) CanProcessMessagesOnTopic(peer core.PeerI return p2pahs.CanProcessMessagesOnTopicCalled(peer, topic, numMessages, totalSize, sequence) } -// ApplyConsensusSize - -func (p2pahs *P2PAntifloodHandlerStub) ApplyConsensusSize(size int) { - if p2pahs.ApplyConsensusSizeCalled != nil { - p2pahs.ApplyConsensusSizeCalled(size) +// SetConsensusSizeNotifier - +func (p2pahs *P2PAntifloodHandlerStub) SetConsensusSizeNotifier(subscriber process.ChainParametersSubscriber, shardID uint32) { + if p2pahs.SetConsensusSizeNotifierCalled != nil { + p2pahs.SetConsensusSizeNotifierCalled(subscriber, shardID) } } diff --git a/process/mock/peerShardResolverStub.go b/process/mock/peerShardResolverStub.go index 4239fbeaee4..a5bd8a66d98 100644 --- a/process/mock/peerShardResolverStub.go +++ b/process/mock/peerShardResolverStub.go @@ -11,7 +11,11 @@ type PeerShardResolverStub struct { // GetPeerInfo - func (psrs *PeerShardResolverStub) GetPeerInfo(pid core.PeerID) core.P2PPeerInfo { - return psrs.GetPeerInfoCalled(pid) + if psrs.GetPeerInfoCalled != nil { + return psrs.GetPeerInfoCalled(pid) + } + + return core.P2PPeerInfo{} } // IsInterfaceNil - diff --git a/process/mock/raterMock.go b/process/mock/raterMock.go index 5f64b6f3b7d..248fefe16b6 100644 --- a/process/mock/raterMock.go +++ b/process/mock/raterMock.go @@ -118,7 +118,10 @@ func (rm *RaterMock) GetStartRating() uint32 { // GetSignedBlocksThreshold - func (rm *RaterMock) GetSignedBlocksThreshold() float32 { - return rm.GetSignedBlocksThresholdCalled() + if rm.GetSignedBlocksThresholdCalled != nil { + return rm.GetSignedBlocksThresholdCalled() + } + return 0.01 } // ComputeIncreaseProposer - diff --git a/process/mock/ratingsInfoMock.go b/process/mock/ratingsInfoMock.go index 5378b889cda..25250ebee30 100644 --- a/process/mock/ratingsInfoMock.go +++ b/process/mock/ratingsInfoMock.go @@ -1,6 +1,9 @@ package mock -import "github.com/multiversx/mx-chain-go/process" +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/process" +) // RatingsInfoMock - type RatingsInfoMock struct { @@ -11,6 +14,7 @@ type RatingsInfoMock struct { MetaRatingsStepDataProperty process.RatingsStepHandler ShardRatingsStepDataProperty process.RatingsStepHandler SelectionChancesProperty []process.SelectionChance + SetStatusHandlerCalled 
func(handler core.AppStatusHandler) error } // StartRating - @@ -48,6 +52,14 @@ func (rd *RatingsInfoMock) ShardChainRatingsStepHandler() process.RatingsStepHan return rd.ShardRatingsStepDataProperty } +// SetStatusHandler - +func (rd *RatingsInfoMock) SetStatusHandler(handler core.AppStatusHandler) error { + if rd.SetStatusHandlerCalled != nil { + return rd.SetStatusHandlerCalled(handler) + } + return nil +} + // IsInterfaceNil - func (rd *RatingsInfoMock) IsInterfaceNil() bool { return rd == nil diff --git a/process/mock/rounderMock.go b/process/mock/rounderMock.go index 047c787ced9..90d5f356405 100644 --- a/process/mock/rounderMock.go +++ b/process/mock/rounderMock.go @@ -21,6 +21,12 @@ func (rndm *RoundHandlerMock) BeforeGenesis() bool { return false } +// RevertOneRound - +func (rndm *RoundHandlerMock) RevertOneRound() { + rndm.RoundIndex-- + rndm.RoundTimeStamp = rndm.RoundTimeStamp.Add(-rndm.RoundTimeDuration) +} + // Index - func (rndm *RoundHandlerMock) Index() int64 { return rndm.RoundIndex diff --git a/process/peer/process.go b/process/peer/process.go index 6ca77b7cbed..24c89a08dc6 100644 --- a/process/peer/process.go +++ b/process/peer/process.go @@ -1,6 +1,7 @@ package peer import ( + "bytes" "context" "encoding/hex" "fmt" @@ -388,7 +389,7 @@ func (vs *validatorStatistics) UpdatePeerState(header data.MetaHeaderHandler, ca log.Trace("Increasing", "round", previousHeader.GetRound(), "prevRandSeed", previousHeader.GetPrevRandSeed()) consensusGroupEpoch := computeEpoch(previousHeader) - consensusGroup, err := vs.nodesCoordinator.ComputeConsensusGroup( + leader, consensusGroup, err := vs.nodesCoordinator.ComputeConsensusGroup( previousHeader.GetPrevRandSeed(), previousHeader.GetRound(), previousHeader.GetShardID(), @@ -397,15 +398,18 @@ func (vs *validatorStatistics) UpdatePeerState(header data.MetaHeaderHandler, ca return nil, err } - encodedLeaderPk := vs.pubkeyConv.SilentEncode(consensusGroup[0].PubKey(), log) + encodedLeaderPk := vs.pubkeyConv.SilentEncode(leader.PubKey(), log) leaderPK := core.GetTrimmedPk(encodedLeaderPk) log.Trace("Increasing for leader", "leader", leaderPK, "round", previousHeader.GetRound()) log.Debug("UpdatePeerState - registering meta previous leader fees", "metaNonce", previousHeader.GetNonce()) + + bitmap := vs.getBitmapForHeader(previousHeader) err = vs.updateValidatorInfoOnSuccessfulBlock( + leader, consensusGroup, - previousHeader.GetPubKeysBitmap(), + bitmap, big.NewInt(0).Sub(previousHeader.GetAccumulatedFees(), previousHeader.GetDeveloperFees()), previousHeader.GetShardID(), previousHeader.GetEpoch(), @@ -424,6 +428,14 @@ func (vs *validatorStatistics) UpdatePeerState(header data.MetaHeaderHandler, ca return rootHash, nil } +func (vs *validatorStatistics) getBitmapForHeader(header data.HeaderHandler) []byte { + bitmap := header.GetPubKeysBitmap() + if vs.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + bitmap = vs.getBitmapForFullConsensus(header.GetShardID(), header.GetEpoch()) + } + return bitmap +} + func computeEpoch(header data.HeaderHandler) uint32 { // TODO: change if start of epoch block needs to be validated by the new epoch nodes // previous block was proposed by the consensus group of the previous epoch @@ -659,6 +671,10 @@ func (vs *validatorStatistics) verifySignaturesBelowSignedThreshold( return nil } + if vs.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + return nil + } + validatorOccurrences := core.MaxUint32(1, 
validator.GetValidatorSuccess()+validator.GetValidatorFailure()+validator.GetValidatorIgnoredSignatures()) computedThreshold := float32(validator.GetValidatorSuccess()) / float32(validatorOccurrences) @@ -800,16 +816,16 @@ func (vs *validatorStatistics) computeDecrease( swInner.Start("ComputeValidatorsGroup") log.Debug("decreasing", "round", i, "prevRandSeed", prevRandSeed, "shardId", shardID) - consensusGroup, err := vs.nodesCoordinator.ComputeConsensusGroup(prevRandSeed, i, shardID, epoch) + leader, consensusGroup, err := vs.nodesCoordinator.ComputeConsensusGroup(prevRandSeed, i, shardID, epoch) swInner.Stop("ComputeValidatorsGroup") if err != nil { return err } swInner.Start("loadPeerAccount") - leaderPeerAcc, err := vs.loadPeerAccount(consensusGroup[0].PubKey()) + leaderPeerAcc, err := vs.loadPeerAccount(leader.PubKey()) - encodedLeaderPk := vs.pubkeyConv.SilentEncode(consensusGroup[0].PubKey(), log) + encodedLeaderPk := vs.pubkeyConv.SilentEncode(leader.PubKey(), log) leaderPK := core.GetTrimmedPk(encodedLeaderPk) swInner.Stop("loadPeerAccount") if err != nil { @@ -817,7 +833,7 @@ func (vs *validatorStatistics) computeDecrease( } vs.mutValidatorStatistics.Lock() - vs.missedBlocksCounters.decreaseLeader(consensusGroup[0].PubKey()) + vs.missedBlocksCounters.decreaseLeader(leader.PubKey()) vs.mutValidatorStatistics.Unlock() swInner.Start("ComputeDecreaseProposer") @@ -895,6 +911,17 @@ func (vs *validatorStatistics) RevertPeerState(header data.MetaHeaderHandler) er return vs.peerAdapter.RecreateTrie(rootHashHolder) } +// TODO: check if this can be taken from somewhere else +func (vs *validatorStatistics) getBitmapForFullConsensus(shardID uint32, epoch uint32) []byte { + consensusSize := vs.nodesCoordinator.ConsensusGroupSizeForShardAndEpoch(shardID, epoch) + bitmap := make([]byte, consensusSize/8+1) + for i := 0; i < consensusSize; i++ { + bitmap[i/8] |= 1 << (uint16(i) % 8) + } + + return bitmap +} + func (vs *validatorStatistics) updateShardDataPeerState( header data.HeaderHandler, cacheMap map[string]data.HeaderHandler, @@ -921,15 +948,20 @@ func (vs *validatorStatistics) updateShardDataPeerState( epoch := computeEpoch(currentHeader) - shardConsensus, shardInfoErr := vs.nodesCoordinator.ComputeConsensusGroup(h.PrevRandSeed, h.Round, h.ShardID, epoch) + leader, shardConsensus, shardInfoErr := vs.nodesCoordinator.ComputeConsensusGroup(h.PrevRandSeed, h.Round, h.ShardID, epoch) if shardInfoErr != nil { return shardInfoErr } log.Debug("updateShardDataPeerState - registering shard leader fees", "shard headerHash", h.HeaderHash, "accumulatedFees", h.AccumulatedFees.String(), "developerFees", h.DeveloperFees.String()) + bitmap := h.PubKeysBitmap + if vs.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, h.Epoch) { + bitmap = vs.getBitmapForFullConsensus(h.ShardID, h.Epoch) + } shardInfoErr = vs.updateValidatorInfoOnSuccessfulBlock( + leader, shardConsensus, - h.PubKeysBitmap, + bitmap, big.NewInt(0).Sub(h.AccumulatedFees, h.DeveloperFees), h.ShardID, currentHeader.GetEpoch(), @@ -1016,6 +1048,7 @@ func (vs *validatorStatistics) savePeerAccountData( } func (vs *validatorStatistics) updateValidatorInfoOnSuccessfulBlock( + leader nodesCoordinator.Validator, validatorList []nodesCoordinator.Validator, signingBitmap []byte, accumulatedFees *big.Int, @@ -1036,7 +1069,7 @@ func (vs *validatorStatistics) updateValidatorInfoOnSuccessfulBlock( peerAcc.IncreaseNumSelectedInSuccessBlocks() newRating := peerAcc.GetRating() - isLeader := i == 0 + isLeader := bytes.Equal(leader.PubKey(), 
validatorList[i].PubKey()) validatorSigned := (signingBitmap[i/8] & (1 << (uint16(i) % 8))) != 0 actionType := vs.computeValidatorActionType(isLeader, validatorSigned) @@ -1167,6 +1200,11 @@ func (vs *validatorStatistics) getTempRating(s string) uint32 { } func (vs *validatorStatistics) display(validatorKey string) { + if log.GetLevel() != logger.LogTrace { + // do not need to load peer account if not log level trace + return + } + peerAcc, err := vs.loadPeerAccount([]byte(validatorKey)) if err != nil { log.Trace("display peer acc", "error", err) @@ -1195,7 +1233,7 @@ func (vs *validatorStatistics) decreaseAll( } log.Debug("ValidatorStatistics decreasing all", "shardID", shardID, "missedRounds", missedRounds) - consensusGroupSize := vs.nodesCoordinator.ConsensusGroupSize(shardID) + consensusGroupSize := vs.nodesCoordinator.ConsensusGroupSizeForShardAndEpoch(shardID, epoch) validators, err := vs.nodesCoordinator.GetAllEligibleValidatorsPublicKeys(epoch) if err != nil { return err diff --git a/process/peer/process_test.go b/process/peer/process_test.go index 00af86cc735..32630292f3c 100644 --- a/process/peer/process_test.go +++ b/process/peer/process_test.go @@ -470,8 +470,8 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateComputeValidatorErrShouldEr arguments := createMockArguments() arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return nil, computeValidatorsErr + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return nil, nil, computeValidatorsErr }, } validatorStatistics, _ := peer.NewValidatorStatisticsProcessor(arguments) @@ -495,9 +495,10 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateGetExistingAccountErr(t *te } arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{validator}, nil }, } arguments.PeerAdapter = adapter @@ -520,9 +521,10 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateGetExistingAccountInvalidTy } arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{validator}, nil }, } arguments.PeerAdapter = adapter @@ -564,9 +566,11 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateGetHeaderError(t *testing.T 
}, nil }, } + + validator1 := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{}, &shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator1, []nodesCoordinator.Validator{validator1, &shardingMocks.ValidatorMock{}}, nil }, } arguments.ShardCoordinator = shardCoordinatorMock @@ -620,9 +624,15 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateCallsIncrease(t *testing.T) }, nil }, } + + validator1 := &shardingMocks.ValidatorMock{ + PubKeyCalled: func() []byte { + return []byte("pk1") + }, + } arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{}, &shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator1, []nodesCoordinator.Validator{validator1, &shardingMocks.ValidatorMock{PubKeyCalled: func() []byte { return []byte("pk2") }}}, nil }, } arguments.ShardCoordinator = shardCoordinatorMock @@ -1292,9 +1302,11 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateCheckForMissedBlocksErr(t * }, nil }, } + + validator1 := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{}, &shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator1, []nodesCoordinator.Validator{validator1, &shardingMocks.ValidatorMock{}}, nil }, } arguments.ShardCoordinator = shardCoordinatorMock @@ -1360,9 +1372,9 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksNoMissedBlocks(t *test arguments.DataPool = dataRetrieverMock.NewPoolsHolderStub() arguments.StorageService = &storageStubs.ChainStorerStub{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { computeValidatorGroupCalled = true - return nil, nil + return nil, nil, nil }, } arguments.ShardCoordinator = shardCoordinatorMock @@ -1446,8 +1458,8 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksErrOnComputeValidatorL arguments.DataPool = dataRetrieverMock.NewPoolsHolderStub() arguments.StorageService = &storageStubs.ChainStorerStub{} 
arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return nil, computeErr + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return nil, nil, computeErr }, } arguments.ShardCoordinator = shardCoordinatorMock @@ -1473,10 +1485,11 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksErrOnDecrease(t *testi } arguments := createMockArguments() + validator1 := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{ - &shardingMocks.ValidatorMock{}, + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator1, []nodesCoordinator.Validator{ + validator1, }, nil }, } @@ -1507,14 +1520,15 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksCallsDecrease(t *testi } arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{ + PubKeyCalled: func() []byte { + return pubKey + }, + } arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{ - &shardingMocks.ValidatorMock{ - PubKeyCalled: func() []byte { - return pubKey - }, - }, + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{ + validator, }, nil }, } @@ -1558,10 +1572,11 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksWithRoundDifferenceGre } arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{ - &shardingMocks.ValidatorMock{}, + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{ + validator, }, nil }, GetAllEligibleValidatorsPublicKeysCalled: func(_ uint32) (map[uint32][][]byte, error) { @@ -1617,10 +1632,11 @@ func TestValidatorStatisticsProcessor_CheckForMissedBlocksWithRoundDifferenceGre } arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{} arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{ - &shardingMocks.ValidatorMock{}, + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, 
shardId uint32, _ uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{ + validator, }, nil }, GetAllEligibleValidatorsPublicKeysCalled: func(_ uint32) (map[uint32][][]byte, error) { @@ -1819,13 +1835,13 @@ func DoComputeMissingBlocks( arguments := createMockArguments() arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return consensus, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, _ uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return consensus[0], consensus, nil }, GetAllEligibleValidatorsPublicKeysCalled: func(_ uint32) (map[uint32][][]byte, error) { return validatorPublicKeys, nil }, - ConsensusGroupSizeCalled: func(uint32) int { + ConsensusGroupSizeCalled: func(uint32, uint32) int { return consensusGroupSize }, GetValidatorWithPublicKeyCalled: func(publicKey []byte) (nodesCoordinator.Validator, uint32, error) { @@ -1894,14 +1910,18 @@ func TestValidatorStatisticsProcessor_UpdatePeerStateCallsPubKeyForValidator(t * pubKeyCalled := false arguments := createMockArguments() + validator := &shardingMocks.ValidatorMock{ + PubKeyCalled: func() []byte { + pubKeyCalled = true + return make([]byte, 0) + }, + } arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { - return []nodesCoordinator.Validator{&shardingMocks.ValidatorMock{ - PubKeyCalled: func() []byte { - pubKeyCalled = true - return make([]byte, 0) - }, - }, &shardingMocks.ValidatorMock{}}, nil + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { + return validator, []nodesCoordinator.Validator{ + validator, + &shardingMocks.ValidatorMock{}, + }, nil }, } arguments.DataPool = &dataRetrieverMock.PoolsHolderStub{ @@ -2385,6 +2405,40 @@ func TestValidatorStatistics_ProcessValidatorInfosEndOfEpochWithLargeValidatorFa assert.Equal(t, rater.MinRating, vi.GetShardValidatorsInfoMap()[0][0].GetTempRating()) } +func TestValidatorStatistics_ProcessRatingsEndOfEpochAfterEquivalentProofsShouldEarlyExit(t *testing.T) { + t.Parallel() + + arguments := createMockArguments() + rater := createMockRater() + rater.GetSignedBlocksThresholdCalled = func() float32 { + return 0.025 // would have passed the `if computedThreshold <= signedThreshold` condition + } + rater.RevertIncreaseProposerCalled = func(shardId uint32, rating uint32, nrReverts uint32) uint32 { + require.Fail(t, "should have not been called") + return 0 + } + arguments.Rater = rater + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + updateArgumentsWithNeeded(arguments) + + tempRating := uint32(5000) + validatorSuccess := uint32(2) + validatorIgnored := uint32(90) + validatorFailure := uint32(8) + + vi := state.NewShardValidatorsInfoMap() + _ = vi.Add(createMockValidatorInfo(core.MetachainShardId, tempRating, validatorSuccess, validatorIgnored, 
validatorFailure)) + + validatorStatistics, _ := peer.NewValidatorStatisticsProcessor(arguments) + err := validatorStatistics.ProcessRatingsEndOfEpoch(vi, 1) + assert.Nil(t, err) + assert.Equal(t, tempRating, vi.GetShardValidatorsInfoMap()[core.MetachainShardId][0].GetTempRating()) // no change +} + func TestValidatorsProvider_PeerAccoutToValidatorInfo(t *testing.T) { t.Parallel() @@ -2606,13 +2660,13 @@ func createUpdateTestArgs(consensusGroup map[string][]nodesCoordinator.Validator arguments.PeerAdapter = adapter arguments.NodesCoordinator = &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) { + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { key := fmt.Sprintf(consensusGroupFormat, string(randomness), round, shardId, epoch) validatorsArray, ok := consensusGroup[key] if !ok { - return nil, process.ErrEmptyConsensusGroup + return nil, nil, process.ErrEmptyConsensusGroup } - return validatorsArray, nil + return validatorsArray[0], validatorsArray, nil }, } return arguments diff --git a/process/rating/peerHonesty/peerHonesty_test.go b/process/rating/peerHonesty/peerHonesty_test.go index 73ca45e2623..0d7cf263ca6 100644 --- a/process/rating/peerHonesty/peerHonesty_test.go +++ b/process/rating/peerHonesty/peerHonesty_test.go @@ -7,9 +7,12 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) @@ -44,7 +47,7 @@ func TestNewP2pPeerHonesty_NilBlacklistedPkCacheShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( createMockPeerHonestyConfig(), nil, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -59,7 +62,7 @@ func TestNewP2pPeerHonesty_InvalidDecayCoefficientShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -74,7 +77,7 @@ func TestNewP2pPeerHonesty_InvalidDecayUpdateIntervalShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -89,7 +92,7 @@ func TestNewP2pPeerHonesty_InvalidMinScoreShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -104,7 +107,7 @@ func TestNewP2pPeerHonesty_InvalidMaxScoreShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -119,7 +122,7 @@ func TestNewP2pPeerHonesty_InvalidUnitValueShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -134,7 +137,7 @@ func TestNewP2pPeerHonesty_InvalidBadPeerThresholdShouldErr(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, 
&testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.True(t, check.IfNil(pph)) @@ -148,7 +151,7 @@ func TestNewP2pPeerHonesty_ShouldWork(t *testing.T) { pph, err := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, ) assert.False(t, check.IfNil(pph)) @@ -167,7 +170,7 @@ func TestP2pPeerHonesty_Close(t *testing.T) { pph, _ := NewP2pPeerHonestyWithCustomExecuteDelayFunction( cfg, &testscommon.TimeCacheStub{}, - &testscommon.CacherStub{}, + &cache.CacherStub{}, handler, ) @@ -189,7 +192,7 @@ func TestP2pPeerHonesty_ChangeScoreShouldWork(t *testing.T) { pph, _ := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -210,7 +213,7 @@ func TestP2pPeerHonesty_DoubleChangeScoreShouldWork(t *testing.T) { pph, _ := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -243,7 +246,7 @@ func TestP2pPeerHonesty_CheckBlacklistNotBlacklisted(t *testing.T) { return nil }, }, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -275,7 +278,7 @@ func TestP2pPeerHonesty_CheckBlacklistMaxScoreReached(t *testing.T) { return nil }, }, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -310,7 +313,7 @@ func TestP2pPeerHonesty_CheckBlacklistMinScoreReached(t *testing.T) { return nil }, }, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -345,7 +348,7 @@ func TestP2pPeerHonesty_CheckBlacklistHasShouldNotCallUpsert(t *testing.T) { return nil }, }, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -374,7 +377,7 @@ func TestP2pPeerHonesty_CheckBlacklistUpsertErrorsShouldWork(t *testing.T) { return errors.New("expected error") }, }, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" @@ -392,7 +395,7 @@ func TestP2pPeerHonesty_ApplyDecay(t *testing.T) { pph, _ := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pks := []string{"pkMin", "pkMax", "pkNearZero", "pkZero", "pkValue"} @@ -422,7 +425,7 @@ func TestP2pPeerHonesty_ApplyDecayWillEventuallyGoTheScoreToZero(t *testing.T) { pph, _ := NewP2pPeerHonesty( cfg, &testscommon.TimeCacheStub{}, - testscommon.NewCacherMock(), + cache.NewCacherMock(), ) pk := "pk" diff --git a/process/rating/ratingsData.go b/process/rating/ratingsData.go index 5e0b34ce75b..d602745e088 100644 --- a/process/rating/ratingsData.go +++ b/process/rating/ratingsData.go @@ -3,19 +3,32 @@ package rating import ( "fmt" "math" + "sort" + "sync" + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/statusHandler" + "golang.org/x/exp/slices" ) var _ process.RatingsInfoHandler = (*RatingsData)(nil) -const milisecondsInHour = 3600 * 1000 +const millisecondsInHour = 3600 * 1000 + +type ratingsStepsData struct { + enableEpoch uint32 + shardRatingsStepData process.RatingsStepHandler + metaRatingsStepData process.RatingsStepHandler +} type computeRatingStepArg struct { shardSize uint32 consensusSize uint32 - roundTimeMilis uint64 + roundTimeMillis uint64 startRating uint32 maxRating uint32 
hoursToMaxRatingFromStartRating uint32 @@ -27,27 +40,37 @@ type computeRatingStepArg struct { // RatingsData will store information about ratingsComputation type RatingsData struct { - startRating uint32 - maxRating uint32 - minRating uint32 - signedBlocksThreshold float32 - metaRatingsStepData process.RatingsStepHandler - shardRatingsStepData process.RatingsStepHandler - selectionChances []process.SelectionChance + startRating uint32 + maxRating uint32 + minRating uint32 + signedBlocksThreshold float32 + currentRatingsStepData ratingsStepsData + ratingsStepsConfig []ratingsStepsData + selectionChances []process.SelectionChance + chainParametersHandler process.ChainParametersHandler + ratingsSetup config.RatingsConfig + roundDurationInMilliseconds uint64 + mutConfiguration sync.RWMutex + statusHandler core.AppStatusHandler + mutStatusHandler sync.RWMutex } // RatingsDataArg contains information for the creation of the new ratingsData type RatingsDataArg struct { - Config config.RatingsConfig - ShardConsensusSize uint32 - MetaConsensusSize uint32 - ShardMinNodes uint32 - MetaMinNodes uint32 - RoundDurationMiliseconds uint64 + EpochNotifier process.EpochNotifier + Config config.RatingsConfig + ChainParametersHolder process.ChainParametersHandler + RoundDurationMilliseconds uint64 } // NewRatingsData creates a new RatingsData instance func NewRatingsData(args RatingsDataArg) (*RatingsData, error) { + if check.IfNil(args.EpochNotifier) { + return nil, process.ErrNilEpochNotifier + } + if check.IfNil(args.ChainParametersHolder) { + return nil, process.ErrNilChainParametersHandler + } ratingsConfig := args.Config err := verifyRatingsConfig(ratingsConfig) if err != nil { @@ -62,51 +85,239 @@ func NewRatingsData(args RatingsDataArg) (*RatingsData, error) { }) } + // avoid any invalid configuration where ratings are not sorted by epoch + slices.SortFunc(ratingsConfig.ShardChain.RatingStepsByEpoch, func(a, b config.RatingSteps) int { + return int(a.EnableEpoch) - int(b.EnableEpoch) + }) + slices.SortFunc(ratingsConfig.MetaChain.RatingStepsByEpoch, func(a, b config.RatingSteps) int { + return int(a.EnableEpoch) - int(b.EnableEpoch) + }) + + if !checkForEpochZeroConfiguration(args) { + return nil, process.ErrMissingConfigurationForEpochZero + } + + currentChainParameters := args.ChainParametersHolder.CurrentChainParameters() + shardChainRatingSteps, _ := getRatingStepsForEpoch(args.EpochNotifier.CurrentEpoch(), ratingsConfig.ShardChain.RatingStepsByEpoch) arg := computeRatingStepArg{ - shardSize: args.ShardMinNodes, - consensusSize: args.ShardConsensusSize, - roundTimeMilis: args.RoundDurationMiliseconds, + shardSize: currentChainParameters.ShardMinNumNodes, + consensusSize: currentChainParameters.ShardConsensusGroupSize, + roundTimeMillis: args.RoundDurationMilliseconds, startRating: ratingsConfig.General.StartRating, maxRating: ratingsConfig.General.MaxRating, - hoursToMaxRatingFromStartRating: ratingsConfig.ShardChain.HoursToMaxRatingFromStartRating, - proposerDecreaseFactor: ratingsConfig.ShardChain.ProposerDecreaseFactor, - validatorDecreaseFactor: ratingsConfig.ShardChain.ValidatorDecreaseFactor, - consecutiveMissedBlocksPenalty: ratingsConfig.ShardChain.ConsecutiveMissedBlocksPenalty, - proposerValidatorImportance: ratingsConfig.ShardChain.ProposerValidatorImportance, + hoursToMaxRatingFromStartRating: shardChainRatingSteps.HoursToMaxRatingFromStartRating, + proposerDecreaseFactor: shardChainRatingSteps.ProposerDecreaseFactor, + validatorDecreaseFactor: 
shardChainRatingSteps.ValidatorDecreaseFactor, + consecutiveMissedBlocksPenalty: shardChainRatingSteps.ConsecutiveMissedBlocksPenalty, + proposerValidatorImportance: shardChainRatingSteps.ProposerValidatorImportance, } shardRatingStep, err := computeRatingStep(arg) if err != nil { return nil, err } + metaChainRatingSteps, _ := getRatingStepsForEpoch(args.EpochNotifier.CurrentEpoch(), ratingsConfig.MetaChain.RatingStepsByEpoch) arg = computeRatingStepArg{ - shardSize: args.MetaMinNodes, - consensusSize: args.MetaConsensusSize, - roundTimeMilis: args.RoundDurationMiliseconds, + shardSize: currentChainParameters.MetachainMinNumNodes, + consensusSize: currentChainParameters.MetachainConsensusGroupSize, + roundTimeMillis: args.RoundDurationMilliseconds, startRating: ratingsConfig.General.StartRating, maxRating: ratingsConfig.General.MaxRating, - hoursToMaxRatingFromStartRating: ratingsConfig.MetaChain.HoursToMaxRatingFromStartRating, - proposerDecreaseFactor: ratingsConfig.MetaChain.ProposerDecreaseFactor, - validatorDecreaseFactor: ratingsConfig.MetaChain.ValidatorDecreaseFactor, - consecutiveMissedBlocksPenalty: ratingsConfig.MetaChain.ConsecutiveMissedBlocksPenalty, - proposerValidatorImportance: ratingsConfig.MetaChain.ProposerValidatorImportance, + hoursToMaxRatingFromStartRating: metaChainRatingSteps.HoursToMaxRatingFromStartRating, + proposerDecreaseFactor: metaChainRatingSteps.ProposerDecreaseFactor, + validatorDecreaseFactor: metaChainRatingSteps.ValidatorDecreaseFactor, + consecutiveMissedBlocksPenalty: metaChainRatingSteps.ConsecutiveMissedBlocksPenalty, + proposerValidatorImportance: metaChainRatingSteps.ProposerValidatorImportance, } metaRatingStep, err := computeRatingStep(arg) if err != nil { return nil, err } - return &RatingsData{ - startRating: ratingsConfig.General.StartRating, - maxRating: ratingsConfig.General.MaxRating, - minRating: ratingsConfig.General.MinRating, - signedBlocksThreshold: ratingsConfig.General.SignedBlocksThreshold, - metaRatingsStepData: metaRatingStep, - shardRatingsStepData: shardRatingStep, - selectionChances: chances, + ratingsConfigValue := ratingsStepsData{ + enableEpoch: args.EpochNotifier.CurrentEpoch(), + shardRatingsStepData: shardRatingStep, + metaRatingsStepData: metaRatingStep, + } + + ratingData := &RatingsData{ + startRating: ratingsConfig.General.StartRating, + maxRating: ratingsConfig.General.MaxRating, + minRating: ratingsConfig.General.MinRating, + signedBlocksThreshold: ratingsConfig.General.SignedBlocksThreshold, + currentRatingsStepData: ratingsConfigValue, + selectionChances: chances, + chainParametersHandler: args.ChainParametersHolder, + ratingsSetup: ratingsConfig, + roundDurationInMilliseconds: args.RoundDurationMilliseconds, + statusHandler: statusHandler.NewNilStatusHandler(), + } + + err = ratingData.computeRatingStepsConfig(args.ChainParametersHolder.AllChainParameters()) + if err != nil { + return nil, err + } + + args.EpochNotifier.RegisterNotifyHandler(ratingData) + + return ratingData, nil +} + +func checkForEpochZeroConfiguration(args RatingsDataArg) bool { + _, foundShardChainRatingSteps := getRatingStepsForEpoch(0, args.Config.ShardChain.RatingStepsByEpoch) + _, foundMetaChainRatingSteps := getRatingStepsForEpoch(0, args.Config.MetaChain.RatingStepsByEpoch) + _, foundChainParams := getChainParamsForEpoch(0, args.ChainParametersHolder.AllChainParameters()) + + return foundShardChainRatingSteps && foundMetaChainRatingSteps && foundChainParams +} + +func (rd *RatingsData) computeRatingStepsConfig(chainParamsList 
[]config.ChainParametersByEpochConfig) error { + if len(chainParamsList) == 0 { + return process.ErrEmptyChainParametersConfiguration + } + + // there are multiple scenarios when ratingSteps can change: + // 1. chain parameters change in a specific epoch + // 2. shard/meta rating steps change in a specific epoch + // thus we first extract all configured epochs into a map, from all meta, shard and chain parameters + // this way we make sure that for each activation epoch we get the proper config, taking all params into consideration + configuredEpochsMap := make(map[uint32]struct{}) + for _, ratingStepsForEpoch := range rd.ratingsSetup.ShardChain.RatingStepsByEpoch { + configuredEpochsMap[ratingStepsForEpoch.EnableEpoch] = struct{}{} + } + + for _, ratingStepsForEpoch := range rd.ratingsSetup.MetaChain.RatingStepsByEpoch { + configuredEpochsMap[ratingStepsForEpoch.EnableEpoch] = struct{}{} + } + + for _, chainParams := range chainParamsList { + configuredEpochsMap[chainParams.EnableEpoch] = struct{}{} + } + + ratingsStepsConfig := make([]ratingsStepsData, 0) + for epoch := range configuredEpochsMap { + configForEpoch, err := rd.computeRatingStepsConfigForEpoch(epoch, chainParamsList) + if err != nil { + return err + } + + ratingsStepsConfig = append(ratingsStepsConfig, configForEpoch) + } + + // sort the config values in descending order + sort.SliceStable(ratingsStepsConfig, func(i, j int) bool { + return ratingsStepsConfig[i].enableEpoch > ratingsStepsConfig[j].enableEpoch + }) + + earliestConfig := ratingsStepsConfig[len(ratingsStepsConfig)-1] + if earliestConfig.enableEpoch != 0 { + return process.ErrMissingConfigurationForEpochZero + } + + rd.ratingsStepsConfig = ratingsStepsConfig + + return nil +} + +func (rd *RatingsData) computeRatingStepsConfigForEpoch( + epoch uint32, + chainParamsList []config.ChainParametersByEpochConfig, +) (ratingsStepsData, error) { + chainParams, _ := getChainParamsForEpoch(epoch, chainParamsList) + + shardChainRatingSteps, _ := getRatingStepsForEpoch(epoch, rd.ratingsSetup.ShardChain.RatingStepsByEpoch) + shardRatingsStepsArgs := computeRatingStepArg{ + shardSize: chainParams.ShardMinNumNodes, + consensusSize: chainParams.ShardConsensusGroupSize, + roundTimeMillis: rd.roundDurationInMilliseconds, + startRating: rd.ratingsSetup.General.StartRating, + maxRating: rd.ratingsSetup.General.MaxRating, + hoursToMaxRatingFromStartRating: shardChainRatingSteps.HoursToMaxRatingFromStartRating, + proposerDecreaseFactor: shardChainRatingSteps.ProposerDecreaseFactor, + validatorDecreaseFactor: shardChainRatingSteps.ValidatorDecreaseFactor, + consecutiveMissedBlocksPenalty: shardChainRatingSteps.ConsecutiveMissedBlocksPenalty, + proposerValidatorImportance: shardChainRatingSteps.ProposerValidatorImportance, + } + shardRatingsStepData, err := computeRatingStep(shardRatingsStepsArgs) + if err != nil { + return ratingsStepsData{}, fmt.Errorf("%w while computing shard rating steps for epoch %d", err, chainParams.EnableEpoch) + } + + metaChainRatingSteps, _ := getRatingStepsForEpoch(epoch, rd.ratingsSetup.MetaChain.RatingStepsByEpoch) + metaRatingsStepsArgs := computeRatingStepArg{ + shardSize: chainParams.MetachainMinNumNodes, + consensusSize: chainParams.MetachainConsensusGroupSize, + roundTimeMillis: rd.roundDurationInMilliseconds, + startRating: rd.ratingsSetup.General.StartRating, + maxRating: rd.ratingsSetup.General.MaxRating, + hoursToMaxRatingFromStartRating: metaChainRatingSteps.HoursToMaxRatingFromStartRating, + proposerDecreaseFactor: 
metaChainRatingSteps.ProposerDecreaseFactor, + validatorDecreaseFactor: metaChainRatingSteps.ValidatorDecreaseFactor, + consecutiveMissedBlocksPenalty: metaChainRatingSteps.ConsecutiveMissedBlocksPenalty, + proposerValidatorImportance: metaChainRatingSteps.ProposerValidatorImportance, + } + metaRatingsStepData, err := computeRatingStep(metaRatingsStepsArgs) + if err != nil { + return ratingsStepsData{}, fmt.Errorf("%w while computing metachain rating steps for epoch %d", err, chainParams.EnableEpoch) + } + + return ratingsStepsData{ + enableEpoch: epoch, + shardRatingsStepData: shardRatingsStepData, + metaRatingsStepData: metaRatingsStepData, }, nil } +// EpochConfirmed will be called whenever a new epoch is confirmed +func (rd *RatingsData) EpochConfirmed(epoch uint32, _ uint64) { + log.Debug("RatingsData - epoch confirmed", "epoch", epoch) + + rd.mutConfiguration.Lock() + defer rd.mutConfiguration.Unlock() + + newVersion, err := rd.getMatchingVersion(epoch) + if err != nil { + log.Error("RatingsData.EpochConfirmed - cannot get matching version", "epoch", epoch, "error", err) + return + } + + if rd.currentRatingsStepData.enableEpoch == newVersion.enableEpoch { + return + } + + rd.currentRatingsStepData = newVersion + + log.Debug("updated shard ratings step data", + "epoch", epoch, + "proposer increase rating step", newVersion.shardRatingsStepData.ProposerIncreaseRatingStep(), + "proposer decrease rating step", newVersion.shardRatingsStepData.ProposerDecreaseRatingStep(), + "validator increase rating step", newVersion.shardRatingsStepData.ValidatorIncreaseRatingStep(), + "validator decrease rating step", newVersion.shardRatingsStepData.ValidatorDecreaseRatingStep(), + ) + + log.Debug("updated metachain ratings step data", + "epoch", epoch, + "proposer increase rating step", newVersion.metaRatingsStepData.ProposerIncreaseRatingStep(), + "proposer decrease rating step", newVersion.metaRatingsStepData.ProposerDecreaseRatingStep(), + "validator increase rating step", newVersion.metaRatingsStepData.ValidatorIncreaseRatingStep(), + "validator decrease rating step", newVersion.metaRatingsStepData.ValidatorDecreaseRatingStep(), + ) + + rd.updateRatingsMetrics(epoch) +} + +func (rd *RatingsData) getMatchingVersion(epoch uint32) (ratingsStepsData, error) { + // the config values are sorted in descending order, so the matching version is the first one whose enable epoch is less than or equal to the provided epoch + for _, ratingsStepConfig := range rd.ratingsStepsConfig { + if ratingsStepConfig.enableEpoch <= epoch { + return ratingsStepConfig, nil + } + } + + // the code should never reach this point, since the config values are checked in the constructor + return ratingsStepsData{}, process.ErrNoMatchingConfigForProvidedEpoch +} + func verifyRatingsConfig(settings config.RatingsConfig) error { if settings.General.MinRating < 1 { return process.ErrMinRatingSmallerThanOne @@ -129,43 +340,52 @@ func verifyRatingsConfig(settings config.RatingsConfig) error { process.ErrSignedBlocksThresholdNotBetweenZeroAndOne, settings.General.SignedBlocksThreshold) } - if settings.ShardChain.HoursToMaxRatingFromStartRating == 0 { - return fmt.Errorf("%w hoursToMaxRatingFromStartRating: shardChain", - process.ErrHoursToMaxRatingFromStartRatingZero) - } - if settings.MetaChain.HoursToMaxRatingFromStartRating == 0 { - return fmt.Errorf("%w hoursToMaxRatingFromStartRating: metachain", - process.ErrHoursToMaxRatingFromStartRatingZero) - } - if settings.MetaChain.ConsecutiveMissedBlocksPenalty < 1 { - return fmt.Errorf("%w: 
metaChain consecutiveMissedBlocksPenalty: %v", - process.ErrConsecutiveMissedBlocksPenaltyLowerThanOne, - settings.MetaChain.ConsecutiveMissedBlocksPenalty) - } - if settings.ShardChain.ConsecutiveMissedBlocksPenalty < 1 { - return fmt.Errorf("%w: shardChain consecutiveMissedBlocksPenalty: %v", - process.ErrConsecutiveMissedBlocksPenaltyLowerThanOne, - settings.ShardChain.ConsecutiveMissedBlocksPenalty) + err := checkRatingStepsByEpochConfigForDest(settings.ShardChain.RatingStepsByEpoch, "shardChain") + if err != nil { + return err } - if settings.ShardChain.ProposerDecreaseFactor > -1 || settings.ShardChain.ValidatorDecreaseFactor > -1 { - return fmt.Errorf("%w: shardChain decrease steps - proposer: %v, validator: %v", - process.ErrDecreaseRatingsStepMoreThanMinusOne, - settings.ShardChain.ProposerDecreaseFactor, - settings.ShardChain.ValidatorDecreaseFactor) + + return checkRatingStepsByEpochConfigForDest(settings.MetaChain.RatingStepsByEpoch, "metaChain") +} + +func checkRatingStepsByEpochConfigForDest(ratingStepsByEpoch []config.RatingSteps, configDestination string) error { + if len(ratingStepsByEpoch) == 0 { + return fmt.Errorf("%w for %s", + process.ErrInvalidRatingsConfig, + configDestination) } - if settings.MetaChain.ProposerDecreaseFactor > -1 || settings.MetaChain.ValidatorDecreaseFactor > -1 { - return fmt.Errorf("%w: metachain decrease steps - proposer: %v, validator: %v", - process.ErrDecreaseRatingsStepMoreThanMinusOne, - settings.MetaChain.ProposerDecreaseFactor, - settings.MetaChain.ValidatorDecreaseFactor) + + for _, ratingStepsForEpoch := range ratingStepsByEpoch { + if ratingStepsForEpoch.HoursToMaxRatingFromStartRating == 0 { + return fmt.Errorf("%w hoursToMaxRatingFromStartRating: %s, epoch: %d", + process.ErrHoursToMaxRatingFromStartRatingZero, + configDestination, + ratingStepsForEpoch.EnableEpoch) + } + if ratingStepsForEpoch.ConsecutiveMissedBlocksPenalty < 1 { + return fmt.Errorf("%w: %s consecutiveMissedBlocksPenalty: %v, epoch: %d", + process.ErrConsecutiveMissedBlocksPenaltyLowerThanOne, + configDestination, + ratingStepsForEpoch.ConsecutiveMissedBlocksPenalty, + ratingStepsForEpoch.EnableEpoch) + } + if ratingStepsForEpoch.ProposerDecreaseFactor > -1 || ratingStepsForEpoch.ValidatorDecreaseFactor > -1 { + return fmt.Errorf("%w: %s decrease steps - proposer: %v, validator: %v, epoch: %d", + process.ErrDecreaseRatingsStepMoreThanMinusOne, + configDestination, + ratingStepsForEpoch.ProposerDecreaseFactor, + ratingStepsForEpoch.ValidatorDecreaseFactor, + ratingStepsForEpoch.EnableEpoch) + } } + return nil } func computeRatingStep( arg computeRatingStepArg, ) (process.RatingsStepHandler, error) { - blocksProducedInHours := uint64(arg.hoursToMaxRatingFromStartRating*milisecondsInHour) / arg.roundTimeMilis + blocksProducedInHours := uint64(arg.hoursToMaxRatingFromStartRating*millisecondsInHour) / arg.roundTimeMillis ratingDifference := arg.maxRating - arg.startRating proposerProbability := float32(blocksProducedInHours) / float32(arg.shardSize) @@ -210,40 +430,113 @@ func computeRatingStep( // StartRating will return the start rating func (rd *RatingsData) StartRating() uint32 { + // no need for mutex protection since this value is only set on constructor return rd.startRating } // MaxRating will return the max rating func (rd *RatingsData) MaxRating() uint32 { + // no need for mutex protection since this value is only set on constructor return rd.maxRating } // MinRating will return the min rating func (rd *RatingsData) MinRating() uint32 { + // no need for mutex 
protection since this value is only set on constructor return rd.minRating } // SignedBlocksThreshold will return the signed blocks threshold func (rd *RatingsData) SignedBlocksThreshold() float32 { + // no need for mutex protection since this value is only set on constructor return rd.signedBlocksThreshold } // SelectionChances will return the array of selectionChances and thresholds func (rd *RatingsData) SelectionChances() []process.SelectionChance { + // no need for mutex protection since this value is only set on constructor return rd.selectionChances } // MetaChainRatingsStepHandler returns the RatingsStepHandler used for the Metachain func (rd *RatingsData) MetaChainRatingsStepHandler() process.RatingsStepHandler { - return rd.metaRatingsStepData + rd.mutConfiguration.RLock() + defer rd.mutConfiguration.RUnlock() + + return rd.currentRatingsStepData.metaRatingsStepData } // ShardChainRatingsStepHandler returns the RatingsStepHandler used for the ShardChains func (rd *RatingsData) ShardChainRatingsStepHandler() process.RatingsStepHandler { - return rd.shardRatingsStepData + rd.mutConfiguration.RLock() + defer rd.mutConfiguration.RUnlock() + + return rd.currentRatingsStepData.shardRatingsStepData +} + +// SetStatusHandler sets the provided status handler if not nil +func (rd *RatingsData) SetStatusHandler(handler core.AppStatusHandler) error { + if check.IfNil(handler) { + return process.ErrNilAppStatusHandler + } + + rd.mutStatusHandler.Lock() + rd.statusHandler = handler + rd.mutStatusHandler.Unlock() + + return nil } // IsInterfaceNil returns true if underlying object is nil func (rd *RatingsData) IsInterfaceNil() bool { return rd == nil } + +func (rd *RatingsData) updateRatingsMetrics(epoch uint32) { + rd.mutStatusHandler.RLock() + defer rd.mutStatusHandler.RUnlock() + + currentShardRatingsStep, _ := getRatingStepsForEpoch(epoch, rd.ratingsSetup.ShardChain.RatingStepsByEpoch) + rd.statusHandler.SetUInt64Value(common.MetricRatingsShardChainHoursToMaxRatingFromStartRating, uint64(currentShardRatingsStep.HoursToMaxRatingFromStartRating)) + rd.statusHandler.SetStringValue(common.MetricRatingsShardChainProposerValidatorImportance, fmt.Sprintf("%f", currentShardRatingsStep.ProposerValidatorImportance)) + rd.statusHandler.SetStringValue(common.MetricRatingsShardChainProposerDecreaseFactor, fmt.Sprintf("%f", currentShardRatingsStep.ProposerDecreaseFactor)) + rd.statusHandler.SetStringValue(common.MetricRatingsShardChainValidatorDecreaseFactor, fmt.Sprintf("%f", currentShardRatingsStep.ValidatorDecreaseFactor)) + rd.statusHandler.SetStringValue(common.MetricRatingsShardChainConsecutiveMissedBlocksPenalty, fmt.Sprintf("%f", currentShardRatingsStep.ConsecutiveMissedBlocksPenalty)) + + currentMetaRatingsStep, _ := getRatingStepsForEpoch(epoch, rd.ratingsSetup.MetaChain.RatingStepsByEpoch) + rd.statusHandler.SetUInt64Value(common.MetricRatingsMetaChainHoursToMaxRatingFromStartRating, uint64(currentMetaRatingsStep.HoursToMaxRatingFromStartRating)) + rd.statusHandler.SetStringValue(common.MetricRatingsMetaChainProposerValidatorImportance, fmt.Sprintf("%f", currentMetaRatingsStep.ProposerValidatorImportance)) + rd.statusHandler.SetStringValue(common.MetricRatingsMetaChainProposerDecreaseFactor, fmt.Sprintf("%f", currentMetaRatingsStep.ProposerDecreaseFactor)) + rd.statusHandler.SetStringValue(common.MetricRatingsMetaChainValidatorDecreaseFactor, fmt.Sprintf("%f", currentMetaRatingsStep.ValidatorDecreaseFactor)) + 
rd.statusHandler.SetStringValue(common.MetricRatingsMetaChainConsecutiveMissedBlocksPenalty, fmt.Sprintf("%f", currentMetaRatingsStep.ConsecutiveMissedBlocksPenalty)) +} + +func getRatingStepsForEpoch(epoch uint32, ratingStepsPerEpoch []config.RatingSteps) (config.RatingSteps, bool) { + var ratingSteps config.RatingSteps + found := false + for _, ratingStepsForEpoch := range ratingStepsPerEpoch { + if ratingStepsForEpoch.EnableEpoch <= epoch { + ratingSteps = ratingStepsForEpoch + found = true + } + } + + return ratingSteps, found +} + +func getChainParamsForEpoch(epoch uint32, chainParamsList []config.ChainParametersByEpochConfig) (config.ChainParametersByEpochConfig, bool) { + slices.SortFunc(chainParamsList, func(a, b config.ChainParametersByEpochConfig) int { + return int(a.EnableEpoch) - int(b.EnableEpoch) + }) + + var chainParams config.ChainParametersByEpochConfig + found := false + for _, chainParamsForEpoch := range chainParamsList { + if chainParamsForEpoch.EnableEpoch <= epoch { + chainParams = chainParamsForEpoch + found = true + } + } + + return chainParams, found +} diff --git a/process/rating/ratingsData_test.go b/process/rating/ratingsData_test.go index 22ccd960aeb..4ac3ced60fd 100644 --- a/process/rating/ratingsData_test.go +++ b/process/rating/ratingsData_test.go @@ -8,6 +8,9 @@ import ( "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" + "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -26,21 +29,46 @@ const ( signedBlocksThreshold = 0.025 consecutiveMissedBlocksPenalty = 1.1 - shardMinNodes = 6 - shardConsensusSize = 3 - metaMinNodes = 6 - metaConsensusSize = 6 - roundDurationMiliseconds = 6000 + shardMinNodes = 6 + shardConsensusSize = 3 + metaMinNodes = 6 + metaConsensusSize = 6 + roundDurationMilliseconds = 6000 ) -func createDymmyRatingsData() RatingsDataArg { +func createDummyRatingsData() RatingsDataArg { return RatingsDataArg{ - Config: config.RatingsConfig{}, - ShardConsensusSize: shardConsensusSize, - MetaConsensusSize: metaConsensusSize, - ShardMinNodes: shardMinNodes, - MetaMinNodes: metaMinNodes, - RoundDurationMiliseconds: roundDurationMiliseconds, + Config: config.RatingsConfig{}, + ChainParametersHolder: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + RoundDuration: 4000, + Hysteresis: 0.2, + EnableEpoch: 0, + ShardConsensusGroupSize: shardConsensusSize, + ShardMinNumNodes: shardMinNodes, + MetachainConsensusGroupSize: metaConsensusSize, + MetachainMinNumNodes: metaMinNodes, + Adaptivity: false, + } + }, + AllChainParametersCalled: func() []config.ChainParametersByEpochConfig { + return []config.ChainParametersByEpochConfig{ + { + RoundDuration: 4000, + Hysteresis: 0.2, + EnableEpoch: 0, + ShardConsensusGroupSize: shardConsensusSize, + ShardMinNumNodes: shardMinNodes, + MetachainConsensusGroupSize: metaConsensusSize, + MetachainMinNumNodes: metaMinNodes, + Adaptivity: false, + }, + } + }, + }, + RoundDurationMilliseconds: roundDurationMilliseconds, + EpochNotifier: &epochNotifier.EpochNotifierStub{}, } } @@ -59,30 +87,102 @@ func createDummyRatingsConfig() 
config.RatingsConfig { }, }, ShardChain: config.ShardChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 2, - ProposerValidatorImportance: 1, - ProposerDecreaseFactor: -4, - ValidatorDecreaseFactor: -4, - ConsecutiveMissedBlocksPenalty: consecutiveMissedBlocksPenalty, + RatingStepsByEpoch: []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: consecutiveMissedBlocksPenalty, + EnableEpoch: 0, + }, }, }, MetaChain: config.MetaChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 2, - ProposerValidatorImportance: 1, - ProposerDecreaseFactor: -4, - ValidatorDecreaseFactor: -4, - ConsecutiveMissedBlocksPenalty: consecutiveMissedBlocksPenalty, + RatingStepsByEpoch: []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: consecutiveMissedBlocksPenalty, + EnableEpoch: 0, + }, }, }, } } +func TestNewRatingsData_NilEpochNotifier(t *testing.T) { + t.Parallel() + + ratingsDataArg := createDummyRatingsData() + ratingsDataArg.EpochNotifier = nil + + ratingsData, err := NewRatingsData(ratingsDataArg) + + assert.Nil(t, ratingsData) + assert.True(t, errors.Is(err, process.ErrNilEpochNotifier)) +} + +func TestNewRatingsData_NilChainParametersHolder(t *testing.T) { + t.Parallel() + + ratingsDataArg := createDummyRatingsData() + ratingsDataArg.ChainParametersHolder = nil + + ratingsData, err := NewRatingsData(ratingsDataArg) + + assert.Nil(t, ratingsData) + assert.True(t, errors.Is(err, process.ErrNilChainParametersHandler)) +} + +func TestNewRatingsData_MissingConfigurationForEpoch0(t *testing.T) { + t.Parallel() + + ratingsDataArg := createDummyRatingsData() + ratingsDataArg.Config = createDummyRatingsConfig() + ratingsDataArg.Config.ShardChain.RatingStepsByEpoch[0].EnableEpoch = 37 + ratingsDataArg.Config.MetaChain.RatingStepsByEpoch[0].EnableEpoch = 37 + ratingsDataArg.ChainParametersHolder = &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + RoundDuration: 4000, + Hysteresis: 0.2, + EnableEpoch: 37, + ShardConsensusGroupSize: shardConsensusSize, + ShardMinNumNodes: shardMinNodes, + MetachainConsensusGroupSize: metaConsensusSize, + MetachainMinNumNodes: metaMinNodes, + Adaptivity: false, + } + }, + AllChainParametersCalled: func() []config.ChainParametersByEpochConfig { + return []config.ChainParametersByEpochConfig{ + { + RoundDuration: 4000, + Hysteresis: 0.2, + EnableEpoch: 37, + ShardConsensusGroupSize: shardConsensusSize, + ShardMinNumNodes: shardMinNodes, + MetachainConsensusGroupSize: metaConsensusSize, + MetachainMinNumNodes: metaMinNodes, + Adaptivity: false, + }, + } + }, + } + + ratingsData, err := NewRatingsData(ratingsDataArg) + + assert.Nil(t, ratingsData) + assert.True(t, errors.Is(err, process.ErrMissingConfigurationForEpochZero)) +} + func TestRatingsData_RatingsDataMinGreaterMaxShouldErr(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() ratingsConfig.General.MinRating = 10 ratingsConfig.General.MaxRating = 8 @@ -97,7 +197,7 @@ func TestRatingsData_RatingsDataMinGreaterMaxShouldErr(t *testing.T) { func 
TestRatingsData_RatingsDataMinSmallerThanOne(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() ratingsConfig.General.MinRating = 0 ratingsConfig.General.MaxRating = 8 @@ -111,7 +211,7 @@ func TestRatingsData_RatingsDataMinSmallerThanOne(t *testing.T) { func TestRatingsData_RatingsStartGreaterMaxShouldErr(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() ratingsConfig.General.MinRating = 10 ratingsConfig.General.MaxRating = 100 @@ -126,7 +226,7 @@ func TestRatingsData_RatingsStartGreaterMaxShouldErr(t *testing.T) { func TestRatingsData_RatingsStartLowerMinShouldErr(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() ratingsConfig.General.MinRating = 10 ratingsConfig.General.MaxRating = 100 @@ -141,7 +241,7 @@ func TestRatingsData_RatingsStartLowerMinShouldErr(t *testing.T) { func TestRatingsData_RatingsSignedBlocksThresholdNotBetweenZeroAndOneShouldErr(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() ratingsConfig.General.SignedBlocksThreshold = -0.1 ratingsDataArg.Config = ratingsConfig @@ -161,9 +261,9 @@ func TestRatingsData_RatingsSignedBlocksThresholdNotBetweenZeroAndOneShouldErr(t func TestRatingsData_RatingsConsecutiveMissedBlocksPenaltyLowerThanOneShouldErr(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() - ratingsConfig.MetaChain.ConsecutiveMissedBlocksPenalty = 0.9 + ratingsConfig.MetaChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty = 0.9 ratingsDataArg.Config = ratingsConfig ratingsData, err := NewRatingsData(ratingsDataArg) @@ -171,8 +271,8 @@ func TestRatingsData_RatingsConsecutiveMissedBlocksPenaltyLowerThanOneShouldErr( require.True(t, errors.Is(err, process.ErrConsecutiveMissedBlocksPenaltyLowerThanOne)) require.True(t, strings.Contains(err.Error(), "meta")) - ratingsConfig.MetaChain.ConsecutiveMissedBlocksPenalty = 1.99 - ratingsConfig.ShardChain.ConsecutiveMissedBlocksPenalty = 0.99 + ratingsConfig.MetaChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty = 1.99 + ratingsConfig.ShardChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty = 0.99 ratingsDataArg.Config = ratingsConfig ratingsData, err = NewRatingsData(ratingsDataArg) @@ -181,12 +281,43 @@ func TestRatingsData_RatingsConsecutiveMissedBlocksPenaltyLowerThanOneShouldErr( require.True(t, strings.Contains(err.Error(), "shard")) } +func TestRatingsData_EmptyRatingsConfig(t *testing.T) { + t.Parallel() + + t.Run("shard should error", func(t *testing.T) { + t.Parallel() + + ratingsDataArg := createDummyRatingsData() + ratingsConfig := createDummyRatingsConfig() + ratingsConfig.ShardChain = config.ShardChain{} + ratingsDataArg.Config = ratingsConfig + ratingsData, err := NewRatingsData(ratingsDataArg) + + require.Nil(t, ratingsData) + require.True(t, errors.Is(err, process.ErrInvalidRatingsConfig)) + require.True(t, strings.Contains(err.Error(), "shardChain")) + }) + t.Run("meta should error", func(t *testing.T) { + t.Parallel() + + ratingsDataArg := createDummyRatingsData() + ratingsConfig := createDummyRatingsConfig() + ratingsConfig.MetaChain 
= config.MetaChain{} + ratingsDataArg.Config = ratingsConfig + ratingsData, err := NewRatingsData(ratingsDataArg) + + require.Nil(t, ratingsData) + require.True(t, errors.Is(err, process.ErrInvalidRatingsConfig)) + require.True(t, strings.Contains(err.Error(), "metaChain")) + }) +} + func TestRatingsData_HoursToMaxRatingFromStartRatingZeroErr(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() - ratingsConfig.MetaChain.HoursToMaxRatingFromStartRating = 0 + ratingsConfig.MetaChain.RatingStepsByEpoch[0].HoursToMaxRatingFromStartRating = 0 ratingsDataArg.Config = ratingsConfig ratingsData, err := NewRatingsData(ratingsDataArg) @@ -197,9 +328,9 @@ func TestRatingsData_HoursToMaxRatingFromStartRatingZeroErr(t *testing.T) { func TestRatingsData_PositiveDecreaseRatingsStepsShouldErr(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() - ratingsConfig.MetaChain.ProposerDecreaseFactor = -0.5 + ratingsConfig.MetaChain.RatingStepsByEpoch[0].ProposerDecreaseFactor = -0.5 ratingsDataArg.Config = ratingsConfig ratingsData, err := NewRatingsData(ratingsDataArg) @@ -208,7 +339,7 @@ func TestRatingsData_PositiveDecreaseRatingsStepsShouldErr(t *testing.T) { require.True(t, strings.Contains(err.Error(), "meta")) ratingsConfig = createDummyRatingsConfig() - ratingsConfig.MetaChain.ValidatorDecreaseFactor = -0.5 + ratingsConfig.MetaChain.RatingStepsByEpoch[0].ValidatorDecreaseFactor = -0.5 ratingsDataArg.Config = ratingsConfig ratingsData, err = NewRatingsData(ratingsDataArg) @@ -217,7 +348,7 @@ func TestRatingsData_PositiveDecreaseRatingsStepsShouldErr(t *testing.T) { require.True(t, strings.Contains(err.Error(), "meta")) ratingsConfig = createDummyRatingsConfig() - ratingsConfig.ShardChain.ProposerDecreaseFactor = -0.5 + ratingsConfig.ShardChain.RatingStepsByEpoch[0].ProposerDecreaseFactor = -0.5 ratingsDataArg.Config = ratingsConfig ratingsData, err = NewRatingsData(ratingsDataArg) @@ -226,7 +357,7 @@ func TestRatingsData_PositiveDecreaseRatingsStepsShouldErr(t *testing.T) { require.True(t, strings.Contains(err.Error(), "shard")) ratingsConfig = createDummyRatingsConfig() - ratingsConfig.ShardChain.ValidatorDecreaseFactor = -0.5 + ratingsConfig.ShardChain.RatingStepsByEpoch[0].ValidatorDecreaseFactor = -0.5 ratingsDataArg.Config = ratingsConfig ratingsData, err = NewRatingsData(ratingsDataArg) @@ -238,9 +369,9 @@ func TestRatingsData_PositiveDecreaseRatingsStepsShouldErr(t *testing.T) { func TestRatingsData_UnderflowErr(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() - ratingsConfig.MetaChain.ProposerDecreaseFactor = math.MinInt32 + ratingsConfig.MetaChain.RatingStepsByEpoch[0].ProposerDecreaseFactor = math.MinInt32 ratingsDataArg.Config = ratingsConfig ratingsData, err := NewRatingsData(ratingsDataArg) @@ -248,9 +379,9 @@ func TestRatingsData_UnderflowErr(t *testing.T) { require.True(t, errors.Is(err, process.ErrOverflow)) require.True(t, strings.Contains(err.Error(), "proposerDecrease")) - ratingsDataArg = createDymmyRatingsData() + ratingsDataArg = createDummyRatingsData() ratingsConfig = createDummyRatingsConfig() - ratingsConfig.MetaChain.ValidatorDecreaseFactor = math.MinInt32 + ratingsConfig.MetaChain.RatingStepsByEpoch[0].ValidatorDecreaseFactor = math.MinInt32 
ratingsDataArg.Config = ratingsConfig ratingsData, err = NewRatingsData(ratingsDataArg) @@ -258,9 +389,9 @@ func TestRatingsData_UnderflowErr(t *testing.T) { require.True(t, errors.Is(err, process.ErrOverflow)) require.True(t, strings.Contains(err.Error(), "validatorDecrease")) - ratingsDataArg = createDymmyRatingsData() + ratingsDataArg = createDummyRatingsData() ratingsConfig = createDummyRatingsConfig() - ratingsConfig.ShardChain.ProposerDecreaseFactor = math.MinInt32 + ratingsConfig.ShardChain.RatingStepsByEpoch[0].ProposerDecreaseFactor = math.MinInt32 ratingsDataArg.Config = ratingsConfig ratingsData, err = NewRatingsData(ratingsDataArg) @@ -268,9 +399,9 @@ func TestRatingsData_UnderflowErr(t *testing.T) { require.True(t, errors.Is(err, process.ErrOverflow)) require.True(t, strings.Contains(err.Error(), "proposerDecrease")) - ratingsDataArg = createDymmyRatingsData() + ratingsDataArg = createDummyRatingsData() ratingsConfig = createDummyRatingsConfig() - ratingsConfig.ShardChain.ValidatorDecreaseFactor = math.MinInt32 + ratingsConfig.ShardChain.RatingStepsByEpoch[0].ValidatorDecreaseFactor = math.MinInt32 ratingsDataArg.Config = ratingsConfig ratingsData, err = NewRatingsData(ratingsDataArg) @@ -279,51 +410,294 @@ func TestRatingsData_UnderflowErr(t *testing.T) { require.True(t, strings.Contains(err.Error(), "validatorDecrease")) } +func TestRatingsData_EpochConfirmed(t *testing.T) { + t.Parallel() + + // Activation epochs for this test: + // 0 -> new chain params but same values as epoch 0, new ratingSteps for shard, new ratingSteps for meta + // 4 -> same chain params as epoch 0, new ratingSteps for shard, same meta ratingSteps as epoch 0 + // 5 -> new chain params, same shard ratingSteps as epoch 4, same meta ratingSteps as epoch 0 + // 7 -> same chain params as epoch 5, new shard ratingSteps, new meta ratingSteps + // 10 -> new chain params but same values as epoch 5, same shard ratingSteps as epoch 7, same meta ratingSteps as epoch 7 + // 15 -> new chain params, new shard ratingSteps, new meta ratingSteps + chainParams := make([]config.ChainParametersByEpochConfig, 0) + for i := uint32(0); i <= 15; i += 5 { + newChainParams := config.ChainParametersByEpochConfig{ + RoundDuration: 4000, + Hysteresis: 0.2, + EnableEpoch: i, + ShardConsensusGroupSize: shardConsensusSize, + ShardMinNumNodes: shardMinNodes, + MetachainConsensusGroupSize: metaConsensusSize, + MetachainMinNumNodes: metaMinNodes, + Adaptivity: false, + } + // change consensus size for shard after epoch 5 + if i >= 5 { + newChainParams.ShardConsensusGroupSize = shardConsensusSize + i + } + + chainParams = append(chainParams, newChainParams) + } + expectedChainParamsIdx := 0 + chainParamsHandler := &chainParameters.ChainParametersHandlerStub{ + AllChainParametersCalled: func() []config.ChainParametersByEpochConfig { + return chainParams + }, + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return chainParams[expectedChainParamsIdx] + }, + } + ratingsDataArg := createDummyRatingsData() + ratingsDataArg.Config = createDummyRatingsConfig() + ratingsDataArg.Config.ShardChain.RatingStepsByEpoch = []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 1, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.1, + EnableEpoch: 0, + }, + { + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.5, + 
EnableEpoch: 4, + }, + { + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.7, + EnableEpoch: 7, + }, + { + HoursToMaxRatingFromStartRating: 1, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.8, + EnableEpoch: 15, + }, + } + ratingsDataArg.Config.MetaChain.RatingStepsByEpoch = []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.5, + EnableEpoch: 0, + }, + { + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.7, + EnableEpoch: 7, + }, + { + HoursToMaxRatingFromStartRating: 1, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: 1.9, + EnableEpoch: 15, + }, + } + ratingsDataArg.ChainParametersHolder = chainParamsHandler + rd, err := NewRatingsData(ratingsDataArg) + require.NoError(t, err) + require.NotNil(t, rd) + + cntSetInt64ValueHandler := 0 + cntSetStringValueHandler := 0 + handler := &statusHandler.AppStatusHandlerStub{ + SetUInt64ValueHandler: func(key string, value uint64) { + cntSetInt64ValueHandler++ + }, + SetStringValueHandler: func(key string, value string) { + cntSetStringValueHandler++ + }, + } + _ = rd.SetStatusHandler(handler) + + // ensure that the configs are stored in descending order + currentConfig := rd.ratingsStepsConfig[0] + for i := 1; i < len(rd.ratingsStepsConfig); i++ { + require.Less(t, rd.ratingsStepsConfig[i].enableEpoch, currentConfig.enableEpoch) + currentConfig = rd.ratingsStepsConfig[i] + } + + // check epoch 0 + require.Equal(t, uint32(0), rd.currentRatingsStepData.enableEpoch) + expectedConsecutiveMissedBlocksPenaltyShardEpoch0 := ratingsDataArg.Config.ShardChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch0, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + expectedConsecutiveMissedBlocksPenaltyMetaEpoch0 := ratingsDataArg.Config.MetaChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch0, rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + // check epoch 1, nothing changed, same as before + rd.EpochConfirmed(1, 0) + require.Equal(t, uint32(0), rd.currentRatingsStepData.enableEpoch) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch0, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch0, rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + // check epoch 4, shard changed, chain params changed + rd.EpochConfirmed(4, 0) + require.Equal(t, uint32(4), rd.currentRatingsStepData.enableEpoch) + expectedConsecutiveMissedBlocksPenaltyShardEpoch4 := ratingsDataArg.Config.ShardChain.RatingStepsByEpoch[1].ConsecutiveMissedBlocksPenalty + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch4, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + expectedConsecutiveMissedBlocksPenaltyMetaEpoch4 := 
ratingsDataArg.Config.MetaChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch4, rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + // check epoch 5, nothing changed, same as before, but we have new chain params defined for this epoch + rd.EpochConfirmed(5, 0) + require.Equal(t, uint32(5), rd.currentRatingsStepData.enableEpoch) + expectedChainParamsIdx = 1 // epoch 5 + require.Equal(t, uint32(shardConsensusSize+5), rd.chainParametersHandler.CurrentChainParameters().ShardConsensusGroupSize) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch4, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch4, rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + // check epoch 6, nothing changed, same as before + rd.EpochConfirmed(6, 0) + require.Equal(t, uint32(5), rd.currentRatingsStepData.enableEpoch) + require.Equal(t, uint32(shardConsensusSize+5), rd.chainParametersHandler.CurrentChainParameters().ShardConsensusGroupSize) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch4, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch4, rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + // check epoch 7, shard changed, meta changed, same chain params + rd.EpochConfirmed(7, 0) + require.Equal(t, uint32(7), rd.currentRatingsStepData.enableEpoch) + require.Equal(t, uint32(shardConsensusSize+5), rd.chainParametersHandler.CurrentChainParameters().ShardConsensusGroupSize) + expectedConsecutiveMissedBlocksPenaltyShardEpoch7 := ratingsDataArg.Config.ShardChain.RatingStepsByEpoch[2].ConsecutiveMissedBlocksPenalty + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch7, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + expectedConsecutiveMissedBlocksPenaltyMetaEpoch7 := ratingsDataArg.Config.MetaChain.RatingStepsByEpoch[1].ConsecutiveMissedBlocksPenalty + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch7, rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + // check epoch 9, nothing changed, same as before + rd.EpochConfirmed(9, 0) + require.Equal(t, uint32(7), rd.currentRatingsStepData.enableEpoch) + require.Equal(t, uint32(shardConsensusSize+5), rd.chainParametersHandler.CurrentChainParameters().ShardConsensusGroupSize) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch7, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch7, rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + // check epoch 10, nothing changed, same as before, but we have new chain params defined for this epoch + rd.EpochConfirmed(10, 0) + require.Equal(t, uint32(10), rd.currentRatingsStepData.enableEpoch) + expectedChainParamsIdx = 2 // epoch 10 + require.Equal(t, uint32(shardConsensusSize+10), rd.chainParametersHandler.CurrentChainParameters().ShardConsensusGroupSize) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch7, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch7, 
rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + // check epoch 11, nothing changed, same as before + rd.EpochConfirmed(11, 0) + require.Equal(t, uint32(10), rd.currentRatingsStepData.enableEpoch) + require.Equal(t, uint32(shardConsensusSize+10), rd.chainParametersHandler.CurrentChainParameters().ShardConsensusGroupSize) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch7, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch7, rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + // check epoch 15, shard changed, meta changed, chain params changed + rd.EpochConfirmed(15, 0) + require.Equal(t, uint32(15), rd.currentRatingsStepData.enableEpoch) + expectedChainParamsIdx = 3 // epoch 15 + require.Equal(t, uint32(shardConsensusSize+15), rd.chainParametersHandler.CurrentChainParameters().ShardConsensusGroupSize) + expectedConsecutiveMissedBlocksPenaltyShardEpoch15 := ratingsDataArg.Config.ShardChain.RatingStepsByEpoch[3].ConsecutiveMissedBlocksPenalty + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch15, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + expectedConsecutiveMissedBlocksPenaltyMetaEpoch15 := ratingsDataArg.Config.MetaChain.RatingStepsByEpoch[2].ConsecutiveMissedBlocksPenalty + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch15, rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + // check epoch 429, nothing changed, same as before + rd.EpochConfirmed(429, 0) + require.Equal(t, uint32(15), rd.currentRatingsStepData.enableEpoch) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyShardEpoch15, rd.currentRatingsStepData.shardRatingsStepData.ConsecutiveMissedBlocksPenalty()) + require.Equal(t, expectedConsecutiveMissedBlocksPenaltyMetaEpoch15, rd.currentRatingsStepData.metaRatingsStepData.ConsecutiveMissedBlocksPenalty()) + + expectedNumberOfConfigChanges := 5 + require.Equal(t, expectedNumberOfConfigChanges*2, cntSetInt64ValueHandler) // for each epoch confirmed should be called twice, shard + meta + require.Equal(t, expectedNumberOfConfigChanges*8, cntSetStringValueHandler) // for each epoch confirmed should be called 8 times, 4 for shard, 4 for meta +} + func TestRatingsData_OverflowErr(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + getBaseChainParams := func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + RoundDuration: 4000, + Hysteresis: 0.2, + EnableEpoch: 0, + ShardConsensusGroupSize: 5, + ShardMinNumNodes: 7, + MetachainConsensusGroupSize: 7, + MetachainMinNumNodes: 7, + Adaptivity: false, + } + } + getChainParametersHandler := func(cfg config.ChainParametersByEpochConfig) *chainParameters.ChainParametersHandlerStub { + return &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return cfg + }, + AllChainParametersCalled: func() []config.ChainParametersByEpochConfig { + return []config.ChainParametersByEpochConfig{cfg} + }, + } + } + + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() ratingsDataArg.Config = ratingsConfig - ratingsDataArg.RoundDurationMiliseconds = 3600 * 1000 - ratingsDataArg.MetaMinNodes = math.MaxUint32 + chainParams := getBaseChainParams() + chainParams.RoundDuration = 3600 * 1000 + 
chainParams.MetachainMinNumNodes = math.MaxUint32 + ratingsDataArg.ChainParametersHolder = getChainParametersHandler(chainParams) ratingsData, err := NewRatingsData(ratingsDataArg) require.Nil(t, ratingsData) require.True(t, errors.Is(err, process.ErrOverflow)) require.True(t, strings.Contains(err.Error(), "proposerIncrease")) - ratingsDataArg = createDymmyRatingsData() + ratingsDataArg = createDummyRatingsData() ratingsConfig = createDummyRatingsConfig() ratingsDataArg.Config = ratingsConfig - ratingsDataArg.RoundDurationMiliseconds = 3600 * 1000 - ratingsDataArg.MetaMinNodes = math.MaxUint32 - ratingsDataArg.MetaConsensusSize = 1 - ratingsDataArg.Config.MetaChain.ProposerValidatorImportance = float32(1) / math.MaxUint32 + chainParams = getBaseChainParams() + chainParams.RoundDuration = 3600 * 1000 + chainParams.MetachainMinNumNodes = math.MaxUint32 + chainParams.MetachainConsensusGroupSize = 1 + ratingsDataArg.ChainParametersHolder = getChainParametersHandler(chainParams) + ratingsDataArg.Config.MetaChain.RatingStepsByEpoch[0].ProposerValidatorImportance = float32(1) / math.MaxUint32 ratingsData, err = NewRatingsData(ratingsDataArg) require.Nil(t, ratingsData) require.True(t, errors.Is(err, process.ErrOverflow)) require.True(t, strings.Contains(err.Error(), "validatorIncrease")) - ratingsDataArg = createDymmyRatingsData() + ratingsDataArg = createDummyRatingsData() ratingsConfig = createDummyRatingsConfig() ratingsDataArg.Config = ratingsConfig - ratingsDataArg.RoundDurationMiliseconds = 3600 * 1000 - ratingsDataArg.ShardMinNodes = math.MaxUint32 + chainParams = getBaseChainParams() + chainParams.RoundDuration = 3600 * 1000 + chainParams.ShardMinNumNodes = math.MaxUint32 + ratingsDataArg.ChainParametersHolder = getChainParametersHandler(chainParams) ratingsData, err = NewRatingsData(ratingsDataArg) require.Nil(t, ratingsData) require.True(t, errors.Is(err, process.ErrOverflow)) require.True(t, strings.Contains(err.Error(), "proposerIncrease")) - ratingsDataArg = createDymmyRatingsData() + ratingsDataArg = createDummyRatingsData() ratingsConfig = createDummyRatingsConfig() ratingsDataArg.Config = ratingsConfig - ratingsDataArg.RoundDurationMiliseconds = 3600 * 1000 - ratingsDataArg.ShardMinNodes = math.MaxUint32 - ratingsDataArg.ShardConsensusSize = 1 - ratingsDataArg.Config.ShardChain.ProposerValidatorImportance = float32(1) / math.MaxUint32 + chainParams = getBaseChainParams() + chainParams.RoundDuration = 3600 * 1000 + chainParams.ShardMinNumNodes = math.MaxUint32 + chainParams.ShardConsensusGroupSize = 1 + ratingsDataArg.ChainParametersHolder = getChainParametersHandler(chainParams) + ratingsDataArg.Config.ShardChain.RatingStepsByEpoch[0].ProposerValidatorImportance = float32(1) / math.MaxUint32 ratingsData, err = NewRatingsData(ratingsDataArg) require.Nil(t, ratingsData) @@ -334,21 +708,21 @@ func TestRatingsData_OverflowErr(t *testing.T) { func TestRatingsData_IncreaseLowerThanZeroErr(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() ratingsConfig := createDummyRatingsConfig() ratingsDataArg.Config = ratingsConfig - ratingsDataArg.Config.MetaChain.HoursToMaxRatingFromStartRating = math.MaxUint32 + ratingsDataArg.Config.MetaChain.RatingStepsByEpoch[0].HoursToMaxRatingFromStartRating = math.MaxUint32 ratingsData, err := NewRatingsData(ratingsDataArg) require.Nil(t, ratingsData) require.True(t, errors.Is(err, process.ErrIncreaseStepLowerThanOne)) require.True(t, strings.Contains(err.Error(), "proposerIncrease")) - 
ratingsDataArg = createDymmyRatingsData() + ratingsDataArg = createDummyRatingsData() ratingsConfig = createDummyRatingsConfig() ratingsDataArg.Config = ratingsConfig - ratingsDataArg.Config.MetaChain.HoursToMaxRatingFromStartRating = 2 - ratingsDataArg.Config.MetaChain.ProposerValidatorImportance = math.MaxUint32 + ratingsDataArg.Config.MetaChain.RatingStepsByEpoch[0].HoursToMaxRatingFromStartRating = 2 + ratingsDataArg.Config.MetaChain.RatingStepsByEpoch[0].ProposerValidatorImportance = math.MaxUint32 ratingsData, err = NewRatingsData(ratingsDataArg) require.Nil(t, ratingsData) @@ -359,7 +733,7 @@ func TestRatingsData_IncreaseLowerThanZeroErr(t *testing.T) { func TestRatingsData_RatingsCorrectValues(t *testing.T) { t.Parallel() - ratingsDataArg := createDymmyRatingsData() + ratingsDataArg := createDummyRatingsData() minRating := uint32(1) maxRating := uint32(10000) startRating := uint32(4000) @@ -373,15 +747,15 @@ func TestRatingsData_RatingsCorrectValues(t *testing.T) { ratingsConfig.General.MinRating = minRating ratingsConfig.General.MaxRating = maxRating ratingsConfig.General.StartRating = startRating - ratingsConfig.MetaChain.HoursToMaxRatingFromStartRating = hoursToMaxRatingFromStartRating - ratingsConfig.ShardChain.HoursToMaxRatingFromStartRating = hoursToMaxRatingFromStartRating + ratingsConfig.MetaChain.RatingStepsByEpoch[0].HoursToMaxRatingFromStartRating = hoursToMaxRatingFromStartRating + ratingsConfig.ShardChain.RatingStepsByEpoch[0].HoursToMaxRatingFromStartRating = hoursToMaxRatingFromStartRating ratingsConfig.General.SignedBlocksThreshold = signedBlocksThreshold - ratingsConfig.ShardChain.ConsecutiveMissedBlocksPenalty = shardConsecutivePenalty - ratingsConfig.ShardChain.ProposerDecreaseFactor = decreaseFactor - ratingsConfig.ShardChain.ValidatorDecreaseFactor = decreaseFactor - ratingsConfig.MetaChain.ConsecutiveMissedBlocksPenalty = metaConsecutivePenalty - ratingsConfig.MetaChain.ProposerDecreaseFactor = decreaseFactor - ratingsConfig.MetaChain.ValidatorDecreaseFactor = decreaseFactor + ratingsConfig.ShardChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty = shardConsecutivePenalty + ratingsConfig.ShardChain.RatingStepsByEpoch[0].ProposerDecreaseFactor = decreaseFactor + ratingsConfig.ShardChain.RatingStepsByEpoch[0].ValidatorDecreaseFactor = decreaseFactor + ratingsConfig.MetaChain.RatingStepsByEpoch[0].ConsecutiveMissedBlocksPenalty = metaConsecutivePenalty + ratingsConfig.MetaChain.RatingStepsByEpoch[0].ProposerDecreaseFactor = decreaseFactor + ratingsConfig.MetaChain.RatingStepsByEpoch[0].ValidatorDecreaseFactor = decreaseFactor selectionChances := []*config.SelectionChance{ {MaxThreshold: 0, ChancePercent: 1}, @@ -416,3 +790,18 @@ func TestRatingsData_RatingsCorrectValues(t *testing.T) { assert.Equal(t, selectionChances[i].ChancePercent, ratingsData.SelectionChances()[i].GetChancePercent()) } } + +func TestRatingsData_SetStatusHandler(t *testing.T) { + t.Parallel() + + ratingsDataArg := createDummyRatingsData() + ratingsDataArg.Config = createDummyRatingsConfig() + ratingsData, _ := NewRatingsData(ratingsDataArg) + require.NotNil(t, ratingsData) + + err := ratingsData.SetStatusHandler(nil) + require.Equal(t, process.ErrNilAppStatusHandler, err) + + err = ratingsData.SetStatusHandler(&statusHandler.AppStatusHandlerStub{}) + require.NoError(t, err) +} diff --git a/process/smartContract/hooks/blockChainHook_test.go b/process/smartContract/hooks/blockChainHook_test.go index a9724f55831..01142134c16 100644 --- a/process/smartContract/hooks/blockChainHook_test.go 
+++ b/process/smartContract/hooks/blockChainHook_test.go @@ -16,6 +16,13 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/esdt" "github.com/multiversx/mx-chain-core-go/data/transaction" + vmcommon "github.com/multiversx/mx-chain-vm-common-go" + vmcommonBuiltInFunctions "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" + "github.com/multiversx/mx-chain-vm-common-go/parsers" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -27,6 +34,7 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" @@ -34,12 +42,6 @@ import ( stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/trie" - vmcommon "github.com/multiversx/mx-chain-vm-common-go" - vmcommonBuiltInFunctions "github.com/multiversx/mx-chain-vm-common-go/builtInFunctions" - "github.com/multiversx/mx-chain-vm-common-go/parsers" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockBlockChainHookArgs() hooks.ArgBlockChainHook { @@ -1581,7 +1583,7 @@ func TestBlockChainHookImpl_SaveCompiledCode(t *testing.T) { args := createMockBlockChainHookArgs() wasCodeSavedInPool := &atomic.Flag{} - args.CompiledSCPool = &testscommon.CacherStub{ + args.CompiledSCPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { require.Equal(t, codeHash, key) return code, true @@ -1603,7 +1605,7 @@ func TestBlockChainHookImpl_SaveCompiledCode(t *testing.T) { args.NilCompiledSCStore = true wasCodeSavedInPool := &atomic.Flag{} - args.CompiledSCPool = &testscommon.CacherStub{ + args.CompiledSCPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { require.Equal(t, codeHash, key) return struct{}{}, true @@ -1636,7 +1638,7 @@ func TestBlockChainHookImpl_SaveCompiledCode(t *testing.T) { }, } wasCodeSavedInPool := &atomic.Flag{} - args.CompiledSCPool = &testscommon.CacherStub{ + args.CompiledSCPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { require.Equal(t, codeHash, key) return nil, false @@ -1673,7 +1675,7 @@ func TestBlockChainHookImpl_SaveCompiledCode(t *testing.T) { }, } args.NilCompiledSCStore = false - args.CompiledSCPool = &testscommon.CacherStub{ + args.CompiledSCPool = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { require.Equal(t, codeHash, key) return 
nil, false @@ -2536,7 +2538,7 @@ func TestBlockChainHookImpl_ClearCompiledCodes(t *testing.T) { args.EnableEpochs.IsPayableBySCEnableEpoch = 11 clearCalled := 0 - args.CompiledSCPool = &testscommon.CacherStub{ClearCalled: func() { + args.CompiledSCPool = &cache.CacherStub{ClearCalled: func() { clearCalled++ }} diff --git a/process/smartContract/processorV2/testScProcessor.go b/process/smartContract/processorV2/testScProcessor.go index 0e8e643605f..52a63c6c308 100644 --- a/process/smartContract/processorV2/testScProcessor.go +++ b/process/smartContract/processorV2/testScProcessor.go @@ -2,6 +2,7 @@ package processorV2 import ( "encoding/hex" + "errors" "fmt" "strings" @@ -79,7 +80,7 @@ func (tsp *TestScProcessor) GetCompositeTestError() error { func wrapErrorIfNotContains(originalError error, msg string) error { if originalError == nil { - return fmt.Errorf(msg) + return errors.New(msg) } alreadyContainsMessage := strings.Contains(originalError.Error(), msg) diff --git a/process/smartContract/scQueryService_test.go b/process/smartContract/scQueryService_test.go index c2ff2035152..a5ec1d2fed5 100644 --- a/process/smartContract/scQueryService_test.go +++ b/process/smartContract/scQueryService_test.go @@ -1186,7 +1186,7 @@ func TestSCQueryService_EpochStartBlockHdrConcurrent(t *testing.T) { BlockHash: []byte(fmt.Sprintf("hash-%d", idx)), } - _, _, err = qs.ExecuteQuery(&query) + _, _, err := qs.ExecuteQuery(&query) require.NoError(t, err) }(i) } diff --git a/process/smartContract/testScProcessor.go b/process/smartContract/testScProcessor.go index a13419ab621..d602619c61e 100644 --- a/process/smartContract/testScProcessor.go +++ b/process/smartContract/testScProcessor.go @@ -2,6 +2,7 @@ package smartContract import ( "encoding/hex" + "errors" "fmt" "strings" @@ -83,7 +84,7 @@ func (tsp *TestScProcessor) GetCompositeTestError() error { func wrapErrorIfNotContains(originalError error, msg string) error { if originalError == nil { - return fmt.Errorf(msg) + return errors.New(msg) } alreadyContainsMessage := strings.Contains(originalError.Error(), msg) diff --git a/process/sync/argBootstrapper.go b/process/sync/argBootstrapper.go index ec3f64a58d8..587ecedd258 100644 --- a/process/sync/argBootstrapper.go +++ b/process/sync/argBootstrapper.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dblookupext" @@ -48,6 +49,7 @@ type ArgBaseBootstrapper struct { ScheduledTxsExecutionHandler process.ScheduledTxsExecutionHandler ProcessWaitTime time.Duration RepopulateTokensSupplies bool + EnableEpochsHandler common.EnableEpochsHandler } // ArgShardBootstrapper holds all dependencies required by the bootstrap data factory in order to create diff --git a/process/sync/baseForkDetector.go b/process/sync/baseForkDetector.go index db5a601524a..04e3938d6ff 100644 --- a/process/sync/baseForkDetector.go +++ b/process/sync/baseForkDetector.go @@ -7,16 +7,19 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" 
"github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/process" ) type headerInfo struct { - epoch uint32 - nonce uint64 - round uint64 - hash []byte - state process.BlockHeaderState + epoch uint32 + nonce uint64 + round uint64 + hash []byte + state process.BlockHeaderState + hasProof bool } type checkpointInfo struct { @@ -43,14 +46,16 @@ type baseForkDetector struct { fork forkInfo mutFork sync.RWMutex - blackListHandler process.TimeCacher - genesisTime int64 - blockTracker process.BlockTracker - forkDetector forkDetector - genesisNonce uint64 - genesisRound uint64 - maxForkHeaderEpoch uint32 - genesisEpoch uint32 + blackListHandler process.TimeCacher + genesisTime int64 + blockTracker process.BlockTracker + forkDetector forkDetector + genesisNonce uint64 + genesisRound uint64 + maxForkHeaderEpoch uint32 + genesisEpoch uint32 + enableEpochsHandler common.EnableEpochsHandler + proofsPool process.ProofsPool } // SetRollBackNonce sets the nonce where the chain should roll back @@ -102,7 +107,7 @@ func (bfd *baseForkDetector) checkBlockBasicValidity( roundDif := int64(header.GetRound()) - int64(bfd.finalCheckpoint().round) nonceDif := int64(header.GetNonce()) - int64(bfd.finalCheckpoint().nonce) - //TODO: Analyze if the acceptance of some headers which came for the next round could generate some attack vectors + // TODO: Analyze if the acceptance of some headers which came for the next round could generate some attack vectors nextRound := bfd.roundHandler.Index() + 1 genesisTimeFromHeader := bfd.computeGenesisTimeFromHeader(header) @@ -111,7 +116,7 @@ func (bfd *baseForkDetector) checkBlockBasicValidity( process.AddHeaderToBlackList(bfd.blackListHandler, headerHash) return process.ErrHeaderIsBlackListed } - //TODO: This check could be removed when this protection mechanism would be implemented on interceptors side + // TODO: This check could be removed when this protection mechanism would be implemented on interceptors side if genesisTimeFromHeader != bfd.genesisTime { process.AddHeaderToBlackList(bfd.blackListHandler, headerHash) return ErrGenesisTimeMissmatch @@ -197,11 +202,17 @@ func (bfd *baseForkDetector) computeProbableHighestNonce() uint64 { probableHighestNonce := bfd.finalCheckpoint().nonce bfd.mutHeaders.RLock() - for nonce := range bfd.headers { + for nonce, headers := range bfd.headers { if nonce <= probableHighestNonce { continue } - probableHighestNonce = nonce + + for _, hInfo := range headers { + if hInfo.hasProof { + probableHighestNonce = nonce + break + } + } } bfd.mutHeaders.RUnlock() @@ -286,8 +297,10 @@ func (bfd *baseForkDetector) append(hdrInfo *headerInfo) bool { return true } + bfd.adjustHeadersWithInfo(hdrInfo) + for _, hdrInfoStored := range hdrInfos { - if bytes.Equal(hdrInfoStored.hash, hdrInfo.hash) && hdrInfoStored.state == hdrInfo.state { + if bytes.Equal(hdrInfoStored.hash, hdrInfo.hash) && hdrInfoStored.state == hdrInfo.state && hdrInfoStored.hasProof == hdrInfo.hasProof { return false } } @@ -296,6 +309,23 @@ func (bfd *baseForkDetector) append(hdrInfo *headerInfo) bool { return true } +func (bfd *baseForkDetector) adjustHeadersWithInfo(hInfo *headerInfo) { + if !bfd.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, hInfo.epoch) { + return + } + + if !hInfo.hasProof { + return + } + + hdrInfos := bfd.headers[hInfo.nonce] + for i := range hdrInfos { + if bytes.Equal(hdrInfos[i].hash, hInfo.hash) { + hdrInfos[i].hasProof = true + } + } +} + // GetHighestFinalBlockNonce gets the highest 
nonce of the block which is final, and it can not be reverted anymore func (bfd *baseForkDetector) GetHighestFinalBlockNonce() uint64 { return bfd.finalCheckpoint().nonce @@ -336,6 +366,16 @@ func (bfd *baseForkDetector) addCheckpoint(checkpoint *checkpointInfo) { bfd.mutFork.Unlock() } +// AddCheckpoint adds a new checkpoint in the list +func (bfd *baseForkDetector) AddCheckpoint(nonce uint64, round uint64, hash []byte) { + checkpoint := &checkpointInfo{ + nonce: nonce, + round: round, + hash: hash, + } + bfd.addCheckpoint(checkpoint) +} + func (bfd *baseForkDetector) lastCheckpoint() *checkpointInfo { bfd.mutFork.RLock() lastIndex := len(bfd.fork.checkpoint) - 1 @@ -682,6 +722,39 @@ func (bfd *baseForkDetector) addHeader( return nil } +// ReceivedProof is called when a proof is received +func (bfd *baseForkDetector) ReceivedProof(proof data.HeaderProofHandler) { + bfd.processReceivedProof(proof) +} + +func (bfd *baseForkDetector) processReceivedProof(proof data.HeaderProofHandler) { + bfd.setHighestNonceReceived(proof.GetHeaderNonce()) + + hInfo := &headerInfo{ + epoch: proof.GetHeaderEpoch(), + nonce: proof.GetHeaderNonce(), + round: proof.GetHeaderRound(), + hash: proof.GetHeaderHash(), + state: process.BHReceived, + hasProof: true, + } + + _ = bfd.append(hInfo) + + probableHighestNonce := bfd.computeProbableHighestNonce() + bfd.setProbableHighestNonce(probableHighestNonce) + + log.Trace("forkDetector.processReceivedProof", + "round", hInfo.round, + "nonce", hInfo.nonce, + "hash", hInfo.hash, + "state", hInfo.state, + "probable highest nonce", bfd.probableHighestNonce(), + "last checkpoint nonce", bfd.lastCheckpoint().nonce, + "final checkpoint nonce", bfd.finalCheckpoint().nonce, + "has proof", hInfo.hasProof) +} + func (bfd *baseForkDetector) processReceivedBlock( header data.HeaderHandler, headerHash []byte, @@ -690,25 +763,34 @@ func (bfd *baseForkDetector) processReceivedBlock( selfNotarizedHeadersHashes [][]byte, doJobOnBHProcessed func(data.HeaderHandler, []byte, []data.HeaderHandler, [][]byte), ) { + hasProof := true // old blocks have consensus proof on them + if common.IsProofsFlagEnabledForHeader(bfd.enableEpochsHandler, header) { + hasProof = bfd.proofsPool.HasProof(header.GetShardID(), headerHash) + } bfd.setHighestNonceReceived(header.GetNonce()) - if state == process.BHProposed { + if state == process.BHProposed || !hasProof { + log.Trace("forkDetector.processReceivedBlock: block is proposed or has no proof", "state", state, "has proof", hasProof) return } isHeaderReceivedTooLate := bfd.isHeaderReceivedTooLate(header, state, process.BlockFinality) if isHeaderReceivedTooLate { + log.Trace("forkDetector.processReceivedBlock: block is received too late", "initial state", state) state = process.BHReceivedTooLate } - appended := bfd.append(&headerInfo{ - epoch: header.GetEpoch(), - nonce: header.GetNonce(), - round: header.GetRound(), - hash: headerHash, - state: state, - }) - if !appended { + hInfo := &headerInfo{ + epoch: header.GetEpoch(), + nonce: header.GetNonce(), + round: header.GetRound(), + hash: headerHash, + state: state, + hasProof: hasProof, + } + + if !bfd.append(hInfo) { + log.Trace("forkDetector.processReceivedBlock: header not appended", "nonce", hInfo.nonce, "hash", hInfo.hash) return } @@ -719,14 +801,15 @@ func (bfd *baseForkDetector) processReceivedBlock( probableHighestNonce := bfd.computeProbableHighestNonce() bfd.setProbableHighestNonce(probableHighestNonce) - log.Debug("forkDetector.AddHeader", - "round", header.GetRound(), - "nonce", 
header.GetNonce(), - "hash", headerHash, - "state", state, + log.Debug("forkDetector.appendHeaderInfo", + "round", hInfo.round, + "nonce", hInfo.nonce, + "hash", hInfo.hash, + "state", hInfo.state, "probable highest nonce", bfd.probableHighestNonce(), "last checkpoint nonce", bfd.lastCheckpoint().nonce, - "final checkpoint nonce", bfd.finalCheckpoint().nonce) + "final checkpoint nonce", bfd.finalCheckpoint().nonce, + "has proof", hInfo.hasProof) } // SetFinalToLastCheckpoint sets the final checkpoint to the last checkpoint added in list diff --git a/process/sync/baseForkDetector_test.go b/process/sync/baseForkDetector_test.go index 10f857bfbce..0d23431b263 100644 --- a/process/sync/baseForkDetector_test.go +++ b/process/sync/baseForkDetector_test.go @@ -1,6 +1,7 @@ package sync_test import ( + "github.com/multiversx/mx-chain-go/testscommon/processMocks" "math" "testing" "time" @@ -8,10 +9,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/stretchr/testify/assert" ) @@ -23,6 +29,8 @@ func TestNewBasicForkDetector_ShouldErrNilRoundHandler(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, process.ErrNilRoundHandler, err) assert.Nil(t, bfd) @@ -37,6 +45,8 @@ func TestNewBasicForkDetector_ShouldErrNilBlackListHandler(t *testing.T) { nil, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, process.ErrNilBlackListCacher, err) assert.Nil(t, bfd) @@ -51,11 +61,45 @@ func TestNewBasicForkDetector_ShouldErrNilBlockTracker(t *testing.T) { &testscommon.TimeCacheStub{}, nil, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Equal(t, process.ErrNilBlockTracker, err) assert.Nil(t, bfd) } +func TestNewBasicForkDetector_ShouldErrNilEnableEpochsHandler(t *testing.T) { + t.Parallel() + + roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 100} + bfd, err := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + nil, + &dataRetriever.ProofsPoolMock{}, + ) + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) + assert.Nil(t, bfd) +} + +func TestNewBasicForkDetector_ShouldErrNilProofsPool(t *testing.T) { + t.Parallel() + + roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 100} + bfd, err := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + nil, + ) + assert.Equal(t, process.ErrNilProofsPool, err) + assert.Nil(t, bfd) +} + func TestNewBasicForkDetector_ShouldWork(t *testing.T) { t.Parallel() @@ -65,6 +109,8 @@ func TestNewBasicForkDetector_ShouldWork(t *testing.T) { 
&testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Nil(t, err) assert.NotNil(t, bfd) @@ -84,6 +130,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldErrGenesisTimeMissmatch(t *te &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, genesisTime, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) err := bfd.CheckBlockValidity(&block.Header{Nonce: 1, Round: round, TimeStamp: incorrectTimeStamp}, []byte("hash")) @@ -102,6 +150,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldErrLowerRoundInBlock(t *testi &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) bfd.SetFinalCheckpoint(1, 1, nil) err := bfd.CheckBlockValidity(&block.Header{PubKeysBitmap: []byte("X")}, []byte("hash")) @@ -117,6 +167,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldErrLowerNonceInBlock(t *testi &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) bfd.SetFinalCheckpoint(2, 2, nil) err := bfd.CheckBlockValidity(&block.Header{Nonce: 1, Round: 3, PubKeysBitmap: []byte("X")}, []byte("hash")) @@ -132,6 +184,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldErrHigherRoundInBlock(t *test &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) err := bfd.CheckBlockValidity(&block.Header{Nonce: 1, Round: 2, PubKeysBitmap: []byte("X")}, []byte("hash")) assert.Equal(t, sync.ErrHigherRoundInBlock, err) @@ -146,6 +200,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldErrHigherNonceInBlock(t *test &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) err := bfd.CheckBlockValidity(&block.Header{Nonce: 2, Round: 1, PubKeysBitmap: []byte("X")}, []byte("hash")) assert.Equal(t, sync.ErrHigherNonceInBlock, err) @@ -160,6 +216,8 @@ func TestBasicForkDetector_CheckBlockValidityShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) err := bfd.CheckBlockValidity(&block.Header{Nonce: 1, Round: 1, PubKeysBitmap: []byte("X")}, []byte("hash")) assert.Nil(t, err) @@ -178,6 +236,8 @@ func TestBasicForkDetector_RemoveHeadersShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 1 @@ -209,6 +269,8 @@ func TestBasicForkDetector_CheckForkOnlyOneShardHeaderOnANonceShouldReturnFalse( &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader( &block.Header{Nonce: 0, PubKeysBitmap: []byte("X")}, @@ -237,6 +299,8 @@ func TestBasicForkDetector_CheckForkOnlyReceivedHeadersShouldReturnFalse(t *test &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader( &block.Header{Nonce: 0, PubKeysBitmap: []byte("X")}, @@ -267,6 +331,8 @@ func TestBasicForkDetector_CheckForkOnlyOneShardHeaderOnANonceReceivedAndProcess 
&testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader( &block.Header{Nonce: 0, PubKeysBitmap: []byte("X")}, @@ -297,6 +363,8 @@ func TestBasicForkDetector_CheckForkMetaHeaderProcessedShouldReturnFalse(t *test &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader( &block.MetaBlock{Nonce: 1, Round: 3, PubKeysBitmap: []byte("X")}, @@ -325,6 +393,8 @@ func TestBasicForkDetector_CheckForkMetaHeaderProcessedShouldReturnFalseWhenLowe &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 5 _ = bfd.AddHeader( @@ -369,6 +439,8 @@ func TestBasicForkDetector_CheckForkMetaHeaderProcessedShouldReturnFalseWhenEqua &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 5 _ = bfd.AddHeader( @@ -412,6 +484,8 @@ func TestBasicForkDetector_CheckForkShardHeaderProcessedShouldReturnTrueWhenEqua &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) hdr1 := &block.Header{Nonce: 1, Round: 4, PubKeysBitmap: []byte("X")} @@ -476,6 +550,8 @@ func TestBasicForkDetector_CheckForkMetaHeaderProcessedShouldReturnTrueWhenEqual &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 5 _ = bfd.AddHeader( @@ -518,6 +594,8 @@ func TestBasicForkDetector_CheckForkShardHeaderProcessedShouldReturnTrueWhenEqua &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) hdr1 := &block.Header{Nonce: 1, Round: 4, PubKeysBitmap: []byte("X")} @@ -581,6 +659,8 @@ func TestBasicForkDetector_CheckForkShouldReturnTrue(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 4 _ = bfd.AddHeader( @@ -625,6 +705,8 @@ func TestBasicForkDetector_CheckForkShouldReturnFalseWhenForkIsOnFinalCheckpoint &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 1 _ = bfd.AddHeader( @@ -661,6 +743,8 @@ func TestBasicForkDetector_CheckForkShouldReturnFalseWhenForkIsOnHigherEpochBloc &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 2 _ = bfd.AddHeader( @@ -703,6 +787,8 @@ func TestBasicForkDetector_RemovePastHeadersShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader(hdr1, hash1, process.BHReceived, nil, nil) _ = bfd.AddHeader(hdr2, hash2, process.BHReceived, nil, nil) @@ -737,6 +823,8 @@ func TestBasicForkDetector_RemoveInvalidReceivedHeadersShouldWork(t *testing.T) &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + 
&enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 11 _ = bfd.AddHeader(hdr0, hash0, process.BHReceived, nil, nil) @@ -775,6 +863,8 @@ func TestBasicForkDetector_RemoveCheckpointHeaderNonceShouldResetCheckpoint(t *t &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) @@ -794,6 +884,8 @@ func TestBasicForkDetector_GetHighestFinalBlockNonce(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) hdr1 := &block.MetaBlock{Nonce: 2, Round: 1, PubKeysBitmap: []byte("X")} @@ -821,6 +913,7 @@ func TestBasicForkDetector_GetHighestFinalBlockNonce(t *testing.T) { assert.Equal(t, uint64(3), bfd.GetHighestFinalBlockNonce()) } +// TODO: add specific tests for equivalent proofs func TestBasicForkDetector_ProbableHighestNonce(t *testing.T) { t.Parallel() @@ -830,6 +923,12 @@ func TestBasicForkDetector_ProbableHighestNonce(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag != common.AndromedaFlag + }, + }, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 11 @@ -882,7 +981,14 @@ func TestShardForkDetector_ShouldAddBlockInForkDetectorShouldWork(t *testing.T) t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - sfd, _ := sync.NewShardForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + sfd, _ := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr := &block.Header{Nonce: 1, Round: 1} receivedTooLate := sfd.IsHeaderReceivedTooLate(hdr, process.BHProcessed, process.BlockFinality) @@ -900,7 +1006,14 @@ func TestShardForkDetector_ShouldAddBlockInForkDetectorShouldErrLowerRoundInBloc t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - sfd, _ := sync.NewShardForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + sfd, _ := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr := &block.Header{Nonce: 1, Round: 1} hdr.Round = uint64(roundHandlerMock.RoundIndex - process.BlockFinality - 1) @@ -912,7 +1025,14 @@ func TestMetaForkDetector_ShouldAddBlockInForkDetectorShouldWork(t *testing.T) { t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - mfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + mfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr := &block.MetaBlock{Nonce: 1, Round: 1} receivedTooLate := mfd.IsHeaderReceivedTooLate(hdr, process.BHProcessed, process.BlockFinality) @@ -930,7 +1050,14 @@ func TestMetaForkDetector_ShouldAddBlockInForkDetectorShouldErrLowerRoundInBlock t.Parallel() roundHandlerMock := 
&mock.RoundHandlerMock{RoundIndex: 10} - mfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + mfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr := &block.MetaBlock{Nonce: 1, Round: 1} hdr.Round = uint64(roundHandlerMock.RoundIndex - process.BlockFinality - 1) @@ -942,7 +1069,14 @@ func TestShardForkDetector_AddNotarizedHeadersShouldNotChangeTheFinalCheckpoint( t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - sfd, _ := sync.NewShardForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + sfd, _ := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr1 := &block.Header{Nonce: 3, Round: 3} hash1 := []byte("hash1") hdr2 := &block.Header{Nonce: 4, Round: 4} @@ -988,7 +1122,14 @@ func TestBaseForkDetector_IsConsensusStuckNotSyncingShouldReturnFalse(t *testing t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{} - bfd, _ := sync.NewShardForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) bfd.SetProbableHighestNonce(1) @@ -1004,6 +1145,8 @@ func TestBaseForkDetector_IsConsensusStuckNoncesDifferencesNotEnoughShouldReturn &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 10 @@ -1019,6 +1162,8 @@ func TestBaseForkDetector_IsConsensusStuckNotInProperRoundShouldReturnFalse(t *t &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 11 @@ -1034,6 +1179,8 @@ func TestBaseForkDetector_IsConsensusStuckShouldReturnTrue(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) // last checkpoint will be (round = 0 , nonce = 0) @@ -1060,6 +1207,8 @@ func TestBaseForkDetector_ComputeTimeDuration(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, genesisTime, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) hdr1 := &block.Header{Nonce: 1, Round: hdrRound, PubKeysBitmap: []byte("X"), TimeStamp: hdrTimeStamp} @@ -1073,7 +1222,14 @@ func TestShardForkDetector_RemoveHeaderShouldComputeFinalCheckpoint(t *testing.T t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 10} - sfd, _ := sync.NewShardForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + sfd, _ := sync.NewShardForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) hdr1 := &block.Header{Nonce: 3, Round: 3} hash1 := []byte("hash1") hdr2 := &block.Header{Nonce: 4, Round: 4} @@ -1114,6 +1270,8 @@ func 
TestBasicForkDetector_CheckForkMetaHeaderProcessedShouldWorkOnEqualRoundWit &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 5 _ = bfd.AddHeader( @@ -1162,6 +1320,8 @@ func TestBasicForkDetector_SetFinalToLastCheckpointShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) roundHandlerMock.RoundIndex = 1000 @@ -1180,3 +1340,212 @@ func TestBasicForkDetector_SetFinalToLastCheckpointShouldWork(t *testing.T) { assert.Equal(t, uint64(900), bfd.GetHighestFinalBlockNonce()) assert.Equal(t, []byte("hash"), bfd.GetHighestFinalBlockHash()) } + +func TestBaseForkDetector_GetNotarizedHeaderHash(t *testing.T) { + t.Parallel() + + roundHandlerMock := &mock.RoundHandlerMock{} + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag != common.AndromedaFlag + }, + }, + &dataRetriever.ProofsPoolMock{}, + ) + + roundHandlerMock.RoundIndex = 10 + _ = bfd.AddHeader( + &block.MetaBlock{PubKeysBitmap: []byte("X"), Nonce: 8, Round: 10}, + []byte("hash0"), + process.BHReceived, + nil, + nil) + + roundHandlerMock.RoundIndex = 11 + _ = bfd.AddHeader( + &block.MetaBlock{PubKeysBitmap: []byte("X"), Nonce: 9, Round: 11}, + []byte("hash1"), + process.BHProcessed, + nil, + nil) + + roundHandlerMock.RoundIndex = 11 + _ = bfd.AddHeader( + &block.MetaBlock{PubKeysBitmap: []byte("X"), Nonce: 9, Round: 11}, + []byte("hash1"), + process.BHNotarized, + nil, + nil) + + hash := bfd.GetNotarizedHeaderHash(7) + assert.Nil(t, hash) + + hash = bfd.GetNotarizedHeaderHash(8) + assert.Nil(t, hash) + + hash = bfd.GetNotarizedHeaderHash(9) + assert.Equal(t, []byte("hash1"), hash) +} + +func TestBaseForkDetector_ReceivedProof(t *testing.T) { + t.Parallel() + + roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 13} + enableEpochsHandlerStub := &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + enableEpochsHandlerStub, + &dataRetriever.ProofsPoolMock{}, + ) + + proof := &processMocks.HeaderProofHandlerStub{ + GetHeaderNonceCalled: func() uint64 { + return 10 + }, + GetHeaderEpochCalled: func() uint32 { + return 1 + }, + GetHeaderRoundCalled: func() uint64 { + return 12 + }, + GetHeaderHashCalled: func() []byte { + return []byte("hash1") + }, + } + bfd.ReceivedProof(proof) + + hdrInfos := bfd.GetHeaders(10) + assert.Len(t, hdrInfos, 1) + assert.Equal(t, []byte("hash1"), hdrInfos[0].Hash()) + + assert.Equal(t, uint64(10), bfd.ProbableHighestNonce()) + + proof2 := &processMocks.HeaderProofHandlerStub{ + GetHeaderNonceCalled: func() uint64 { + return 10 + }, + GetHeaderEpochCalled: func() uint32 { + return 1 + }, + GetHeaderRoundCalled: func() uint64 { + return 13 + }, + GetHeaderHashCalled: func() []byte { + return []byte("hash2") + }, + } + bfd.ReceivedProof(proof2) + + hdrInfos2 := bfd.GetHeaders(10) + assert.Len(t, hdrInfos2, 2) + assert.Equal(t, []byte("hash2"), hdrInfos2[1].Hash()) + + assert.Equal(t, uint64(10), bfd.ProbableHighestNonce()) +} + 
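// Illustration only, not part of this diff: the two HeaderProofHandlerStub literals in
// TestBaseForkDetector_ReceivedProof above differ only in nonce, round, epoch and hash.
// A small, hypothetical constructor helper like the sketch below could reduce that
// duplication in these tests; it relies solely on the stub fields already used above.
func newHeaderProofStub(nonce uint64, round uint64, epoch uint32, hash []byte) *processMocks.HeaderProofHandlerStub {
	return &processMocks.HeaderProofHandlerStub{
		// each getter simply echoes the captured arguments
		GetHeaderNonceCalled: func() uint64 { return nonce },
		GetHeaderRoundCalled: func() uint64 { return round },
		GetHeaderEpochCalled: func() uint32 { return epoch },
		GetHeaderHashCalled:  func() []byte { return hash },
	}
}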
+func TestBaseForkDetector_BlockWithoutProofShouldReturnEarly(t *testing.T) { + t.Parallel() + + roundHandlerMock := &mock.RoundHandlerMock{} + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + }, + &dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return false + }, + }, + ) + + roundHandlerMock.RoundIndex = 10 + _ = bfd.AddHeader( + &block.MetaBlock{PubKeysBitmap: []byte("X"), Nonce: 8, Round: 10}, + []byte("hash0"), + process.BHReceived, + nil, + nil) + assert.Equal(t, uint64(0), bfd.ProbableHighestNonce()) + + roundHandlerMock.RoundIndex = 11 + _ = bfd.AddHeader( + &block.MetaBlock{PubKeysBitmap: []byte("X"), Nonce: 9, Round: 11}, + []byte("hash1"), + process.BHProcessed, + nil, + nil) + assert.Equal(t, uint64(0), bfd.ProbableHighestNonce()) +} + +func TestBaseForkDetector_ReceivedProofForBlockHeaderShouldSetProof(t *testing.T) { + t.Parallel() + + bfd, _ := sync.NewShardForkDetector( + &mock.RoundHandlerMock{}, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + }, + &dataRetriever.ProofsPoolMock{}, + ) + + hdrInfo := &sync.HeaderInfo{ + Epoch: 1, + Round: 0, + Nonce: 0, + Hash: []byte("hash0"), + State: process.BHProcessed, + HasProof: false, + } + bfd.Append(hdrInfo) + + hdrInfos := bfd.GetHeaders(0) + assert.Len(t, hdrInfos, 1) + assert.Equal(t, []byte("hash0"), hdrInfos[0].Hash()) + assert.Equal(t, false, hdrInfos[0].HasProof()) + + proof := &processMocks.HeaderProofHandlerStub{ + GetHeaderEpochCalled: func() uint32 { + return 1 + }, + GetHeaderRoundCalled: func() uint64 { + return 1 + }, + GetHeaderNonceCalled: func() uint64 { + return 0 + }, + GetHeaderHashCalled: func() []byte { + return []byte("hash0") + }, + } + bfd.ReceivedProof(proof) + + hdrInfos = bfd.GetHeaders(0) + assert.Len(t, hdrInfos, 2) + assert.Equal(t, []byte("hash0"), hdrInfos[0].Hash()) + assert.Equal(t, true, hdrInfos[0].HasProof()) + assert.Equal(t, []byte("hash0"), hdrInfos[1].Hash()) + assert.Equal(t, true, hdrInfos[1].HasProof()) +} diff --git a/process/sync/baseSync.go b/process/sync/baseSync.go index aa43d8cecc1..f66330346df 100644 --- a/process/sync/baseSync.go +++ b/process/sync/baseSync.go @@ -3,6 +3,7 @@ package sync import ( "bytes" "context" + "encoding/hex" "fmt" "math" "sync" @@ -17,6 +18,8 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -29,7 +32,6 @@ import ( "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/trie/storageMarker" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("process/sync") @@ -57,21 +59,23 @@ type notarizedInfo struct { type baseBootstrap 
struct { historyRepo dblookupext.HistoryRepository headers dataRetriever.HeadersPool + proofs dataRetriever.ProofsPool chainHandler data.ChainHandler blockProcessor process.BlockProcessor store dataRetriever.StorageService - roundHandler consensus.RoundHandler - hasher hashing.Hasher - marshalizer marshal.Marshalizer - epochHandler dataRetriever.EpochHandler - forkDetector process.ForkDetector - requestHandler process.RequestHandler - shardCoordinator sharding.Coordinator - accounts state.AccountsAdapter - blockBootstrapper blockBootstrapper - blackListHandler process.TimeCacher + roundHandler consensus.RoundHandler + hasher hashing.Hasher + marshalizer marshal.Marshalizer + epochHandler dataRetriever.EpochHandler + forkDetector process.ForkDetector + requestHandler process.RequestHandler + shardCoordinator sharding.Coordinator + accounts state.AccountsAdapter + blockBootstrapper blockBootstrapper + blackListHandler process.TimeCacher + enableEpochsHandler common.EnableEpochsHandler mutHeader sync.RWMutex headerNonce *uint64 @@ -104,8 +108,7 @@ type baseBootstrap struct { requestMiniBlocks func(headerHandler data.HeaderHandler) - networkWatcher process.NetworkConnectionWatcher - getHeaderFromPool func([]byte) (data.HeaderHandler, error) + networkWatcher process.NetworkConnectionWatcher headerStore storage.Storer headerNonceHashStore storage.Storer @@ -158,6 +161,67 @@ func (boot *baseBootstrap) requestedHeaderHash() []byte { return boot.headerhash } +func (boot *baseBootstrap) processReceivedProof(headerProof data.HeaderProofHandler) { + if boot.shardCoordinator.SelfId() != headerProof.GetHeaderShardId() { + return + } + + boot.forkDetector.ReceivedProof(headerProof) + + boot.checkProofCorrespondsToRequestedHash(headerProof) + boot.checkProofCorrespondsToRequestedNonce(headerProof) +} + +func (boot *baseBootstrap) checkProofCorrespondsToRequestedHash(headerProof data.HeaderProofHandler) { + boot.mutRcvHdrHash.RLock() + hash := boot.requestedHeaderHash() + wasHashRequested := hash != nil && bytes.Equal(hash, headerProof.GetHeaderHash()) + if !wasHashRequested { + boot.mutRcvHdrHash.RUnlock() + return + } + + // if header is also received, release the chan and set requested to nil + // otherwise wait for the header + _, err := boot.headers.GetHeaderByHash(headerProof.GetHeaderHash()) + hasHeader := err == nil + if hasHeader { + boot.setRequestedHeaderHash(nil) + boot.mutRcvHdrHash.RUnlock() + + boot.chRcvHdrHash <- true + + return + } + + boot.mutRcvHdrHash.RUnlock() +} + +func (boot *baseBootstrap) checkProofCorrespondsToRequestedNonce(headerProof data.HeaderProofHandler) { + boot.mutRcvHdrNonce.RLock() + n := boot.requestedHeaderNonce() + wasNonceRequested := n != nil && *n == headerProof.GetHeaderNonce() + if !wasNonceRequested { + boot.mutRcvHdrNonce.RUnlock() + return + } + + // if header is also received, release the chan and set requested to nil + // otherwise wait for the header + _, err := boot.headers.GetHeaderByHash(headerProof.GetHeaderHash()) + hasHeader := err == nil + if hasHeader { + boot.setRequestedHeaderNonce(nil) + boot.mutRcvHdrNonce.RUnlock() + + boot.chRcvHdrNonce <- true + + return + } + + boot.mutRcvHdrNonce.RUnlock() +} + func (boot *baseBootstrap) processReceivedHeader(headerHandler data.HeaderHandler, headerHash []byte) { if boot.shardCoordinator.SelfId() != headerHandler.GetShardID() { return @@ -191,9 +255,36 @@ func (boot *baseBootstrap) confirmHeaderReceivedByNonce(headerHandler data.Heade "nonce", headerHandler.GetNonce(), "hash", hdrHash, ) - 
boot.setRequestedHeaderNonce(nil) + + // if flag is not active for the header, do not check the proof and release chan + isFlagActive := common.IsProofsFlagEnabledForHeader(boot.enableEpochsHandler, headerHandler) + if !isFlagActive { + boot.setRequestedHeaderNonce(nil) + boot.mutRcvHdrNonce.Unlock() + + boot.chRcvHdrNonce <- true + + return + } + + // if proof is also received, release chan and set requested to nil + // otherwise, wait for the proof too + hasProof := boot.proofs.HasProof(headerHandler.GetShardID(), hdrHash) + if hasProof { + log.Debug("received requested proof from network", + "shard", headerHandler.GetShardID(), + "round", headerHandler.GetRound(), + "nonce", headerHandler.GetNonce(), + "hash", hdrHash, + ) + boot.setRequestedHeaderNonce(nil) + } boot.mutRcvHdrNonce.Unlock() - boot.chRcvHdrNonce <- true + + if hasProof { + boot.chRcvHdrNonce <- true + } + return } @@ -210,15 +301,50 @@ func (boot *baseBootstrap) confirmHeaderReceivedByHash(headerHandler data.Header "nonce", headerHandler.GetNonce(), "hash", hash, ) - boot.setRequestedHeaderHash(nil) + + // if flag is not active for the header, do not check the proof and release chan + isFlagActive := common.IsProofsFlagEnabledForHeader(boot.enableEpochsHandler, headerHandler) + if !isFlagActive { + boot.setRequestedHeaderHash(nil) + boot.mutRcvHdrHash.Unlock() + + boot.chRcvHdrHash <- true + + return + } + + // if proof is also received, release chan and set requested to nil + // otherwise, wait for the proof too + hasProof := boot.proofs.HasProof(headerHandler.GetShardID(), hash) + if hasProof { + log.Debug("received requested proof from network", + "shard", headerHandler.GetShardID(), + "round", headerHandler.GetRound(), + "nonce", headerHandler.GetNonce(), + "hash", hash, + ) + boot.setRequestedHeaderHash(nil) + } boot.mutRcvHdrHash.Unlock() - boot.chRcvHdrHash <- true + + if hasProof { + boot.chRcvHdrHash <- true + } return } + boot.mutRcvHdrHash.Unlock() } +func (boot *baseBootstrap) hasProof(hash []byte, header data.HeaderHandler) bool { + if !common.IsProofsFlagEnabledForHeader(boot.enableEpochsHandler, header) { + return true + } + + return boot.proofs.HasProof(boot.shardCoordinator.SelfId(), hash) +} + // AddSyncStateListener adds a syncStateListener that get notified each time the sync status of the node changes func (boot *baseBootstrap) AddSyncStateListener(syncStateListener func(isSyncing bool)) { boot.mutSyncStateListeners.Lock() @@ -264,8 +390,8 @@ func (boot *baseBootstrap) getEpochOfCurrentBlock() uint32 { return epoch } -// waitForHeaderNonce method wait for header with the requested nonce to be received -func (boot *baseBootstrap) waitForHeaderNonce() error { +// waitForHeaderAndProofByNonce method wait for header with the requested nonce to be received +func (boot *baseBootstrap) waitForHeaderAndProofByNonce() error { select { case <-boot.chRcvHdrNonce: return nil @@ -275,7 +401,7 @@ func (boot *baseBootstrap) waitForHeaderNonce() error { } // waitForHeaderHash method wait for header with the requested hash to be received -func (boot *baseBootstrap) waitForHeaderHash() error { +func (boot *baseBootstrap) waitForHeaderAndProofByHash() error { select { case <-boot.chRcvHdrHash: return nil @@ -491,6 +617,9 @@ func checkBaseBootstrapParameters(arguments ArgBaseBootstrapper) error { if arguments.ProcessWaitTime < minimumProcessWaitTime { return fmt.Errorf("%w, minimum is %v, provided is %v", process.ErrInvalidProcessWaitTime, minimumProcessWaitTime, arguments.ProcessWaitTime) } + if 
check.IfNil(arguments.EnableEpochsHandler) { + return process.ErrNilEnableEpochsHandler + } return nil } @@ -626,6 +755,8 @@ func (boot *baseBootstrap) syncBlock() error { defer func() { if err != nil { + log.Warn("sync block failed", "error", err) + boot.doJobOnSyncBlockFail(body, header, err) } }() @@ -687,6 +818,7 @@ func (boot *baseBootstrap) syncBlock() error { ) boot.cleanNoncesSyncedWithErrorsBehindFinal() + boot.cleanProofsBehindFinal(header) return nil } @@ -715,6 +847,24 @@ func (boot *baseBootstrap) cleanNoncesSyncedWithErrorsBehindFinal() { } } +func (boot *baseBootstrap) cleanProofsBehindFinal(header data.HeaderHandler) { + if !boot.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + return + } + + finalNonce := boot.forkDetector.GetHighestFinalBlockNonce() + + err := boot.proofs.CleanupProofsBehindNonce(header.GetShardID(), finalNonce) + if err != nil { + log.Warn("failed to cleanup notarized proofs behind nonce", + "nonce", finalNonce, + "shardID", header.GetShardID(), + "error", err) + } + + log.Trace("baseBootstrap.cleanProofsBehindFinal cleanup successfully", "finalNonce", finalNonce) +} + // rollBack decides if rollBackOneBlock must be called func (boot *baseBootstrap) rollBack(revertUsingForkNonce bool) error { var roleBackOneBlockExecuted bool @@ -947,10 +1097,247 @@ func (boot *baseBootstrap) getNextHeaderRequestingIfMissing() (data.HeaderHandle } if hash != nil { - return boot.blockBootstrapper.getHeaderWithHashRequestingIfMissing(hash) + header, err := boot.getHeaderWithHashRequestingIfMissing(hash) + return header, err + } + + return boot.getHeaderWithNonceRequestingIfMissing(nonce) +} + +// getHeaderWithHashRequestingIfMissing method gets the header with a given hash from pool. If it is not found there, +// it will be requested from network +func (boot *baseBootstrap) getHeaderWithHashRequestingIfMissing(hash []byte) (data.HeaderHandler, error) { + hdr, err := boot.getHeader(hash) + hasHeader := err == nil + needsProof := boot.checkNeedsProofByHash(hash, hdr) + if hasHeader && !needsProof { + return hdr, nil + } + + boot.requestHeaderAndProofByHashIfMissing(hash, !hasHeader, needsProof) + + err = boot.waitForHeaderAndProofByHash() + if err != nil { + return nil, err + } + + hdr, err = boot.getHeaderFromPool(hash) + if err != nil { + return nil, err + } + + if !boot.hasProof(hash, hdr) { + return nil, process.ErrMissingHeaderProof + } + + return hdr, nil +} + +func (boot *baseBootstrap) checkNeedsProofByHash(hash []byte, header data.HeaderHandler) bool { + // if header exists, check if it has or needs a proof + // if it has a proof, do not wait + // if it does not need a proof, do not wait + // if it needs a proof, request and wait for the proof + // if header does not exist + // if it has a proof, request the header + // if it does not have the proof, request both and decide when header is received if it truly needed the proof + _, errGetProof := boot.proofs.GetProof(boot.shardCoordinator.SelfId(), hash) + hasProof := errGetProof == nil + needsProof := !hasProof + if check.IfNil(header) { + return needsProof + } + + isFlagActiveForExistingHeader := common.IsProofsFlagEnabledForHeader(boot.enableEpochsHandler, header) + needsProof = needsProof && isFlagActiveForExistingHeader + return needsProof +} + +// getHeaderWithNonceRequestingIfMissing method gets the header with a given nonce from pool. 
If it is not found there, it will +// be requested from network +func (boot *baseBootstrap) getHeaderWithNonceRequestingIfMissing(nonce uint64) (data.HeaderHandler, error) { + hdr, hash, err := boot.getHeaderFromPoolWithNonce(nonce) + hasHeader := err == nil + + if hasHeader && boot.hasProof(hash, hdr) { + return hdr, nil + } + + needsProof := boot.checkNeedsProofByNonce(nonce, hdr, hash) + + if hasHeader { + boot.requestHandler.SetEpoch(hdr.GetEpoch()) + } + + boot.requestHeaderAndProofByNonceIfMissing(hash, nonce, !hasHeader, needsProof) + + err = boot.waitForHeaderAndProofByNonce() + if err != nil { + return nil, err + } + + hdr, hash, err = boot.getHeaderFromPoolWithNonce(nonce) + if err != nil { + return nil, err + } + + if !boot.hasProof(hash, hdr) { + return nil, process.ErrMissingHeaderProof + } + + return hdr, nil +} + +func (boot *baseBootstrap) checkNeedsProofByNonce( + nonce uint64, + header data.HeaderHandler, + headerHash []byte, +) bool { + // if header exists, check if it has or needs a proof + // if it has a proof, do not wait + // if it does not need a proof, do not wait + // if it needs a proof, request and wait for the proof + // if header does not exist + // if it has a proof, request the header + // if it does not have the proof, request both and decide when header is received if it truly needed the proof + proof, errGetProof := boot.proofs.GetProofByNonce(nonce, boot.shardCoordinator.SelfId()) + hasProof := errGetProof == nil + needsProof := !hasProof + + if check.IfNil(header) { + return needsProof + } + + if hasProof && !bytes.Equal(headerHash, proof.GetHeaderHash()) { + needsProof = true + } + + isFlagActiveForExistingHeader := common.IsProofsFlagEnabledForHeader(boot.enableEpochsHandler, header) + needsProof = needsProof && isFlagActiveForExistingHeader + + return needsProof +} + +func (boot *baseBootstrap) requestHeaderAndProofByHashIfMissing( + hash []byte, + needsHeader bool, + needsProof bool, +) { + _ = core.EmptyChannel(boot.chRcvHdrHash) + if needsHeader { + boot.setRequestedHeaderHash(hash) + boot.requestHeaderByHash(hash) + } + + if !needsProof { + return + } + + log.Debug("requesting equivalent proof from network", + "hash", hex.EncodeToString(hash), + ) + + boot.setRequestedHeaderHash(hash) + boot.requestHandler.RequestEquivalentProofByHash(boot.shardCoordinator.SelfId(), hash) +} + +func (boot *baseBootstrap) requestHeaderByHash(hash []byte) { + logMsg := fmt.Sprintf("requesting %s header from network", boot.getShardLabel()) + log.Debug(logMsg, + "hash", hash, + "probable highest nonce", boot.forkDetector.ProbableHighestNonce(), + ) + + if boot.shardCoordinator.SelfId() == core.MetachainShardId { + boot.requestHandler.RequestMetaHeader(hash) + return + } + + boot.requestHandler.RequestShardHeader(boot.shardCoordinator.SelfId(), hash) +} + +func (boot *baseBootstrap) getShardLabel() string { + shardLabel := "meta" + if boot.shardCoordinator.SelfId() != core.MetachainShardId { + shardLabel = "shard" + } + + return shardLabel +} + +func (boot *baseBootstrap) requestHeaderAndProofByNonceIfMissing( + hash []byte, + nonce uint64, + needsHeader bool, + needsProof bool, +) { + _ = core.EmptyChannel(boot.chRcvHdrNonce) + if needsHeader { + boot.setRequestedHeaderNonce(&nonce) + boot.requestHeaderByNonce(nonce) + } + + if !needsProof { + return + } + + if len(hash) == 0 { + log.Debug("requesting equivalent proof from network", + "nonce", nonce, + ) + + boot.setRequestedHeaderNonce(&nonce) + 
boot.requestHandler.RequestEquivalentProofByNonce(boot.shardCoordinator.SelfId(), nonce) + return + } + + log.Debug("requesting equivalent proof from network", + "hash", hex.EncodeToString(hash), + ) + + boot.setRequestedHeaderNonce(&nonce) + boot.requestHandler.RequestEquivalentProofByHash(boot.shardCoordinator.SelfId(), hash) +} + +func (boot *baseBootstrap) requestHeaderByNonce(nonce uint64) { + logMsg := fmt.Sprintf("requesting %s header by nonce from network", boot.getShardLabel()) + log.Debug(logMsg, + "nonce", nonce, + "probable highest nonce", boot.forkDetector.ProbableHighestNonce(), + ) + + if boot.shardCoordinator.SelfId() == core.MetachainShardId { + boot.requestHandler.RequestMetaHeaderByNonce(nonce) + return } - return boot.blockBootstrapper.getHeaderWithNonceRequestingIfMissing(nonce) + boot.requestHandler.RequestShardHeaderByNonce(boot.shardCoordinator.SelfId(), nonce) +} + +func (boot *baseBootstrap) getHeader(hash []byte) (data.HeaderHandler, error) { + if boot.shardCoordinator.SelfId() == core.MetachainShardId { + return process.GetMetaHeader(hash, boot.headers, boot.marshalizer, boot.store) + } + + return process.GetShardHeader(hash, boot.headers, boot.marshalizer, boot.store) +} + +func (boot *baseBootstrap) getHeaderFromPool(hash []byte) (data.HeaderHandler, error) { + if boot.shardCoordinator.SelfId() == core.MetachainShardId { + return process.GetMetaHeaderFromPool(hash, boot.headers) + } + + return process.GetShardHeaderFromPool(hash, boot.headers) +} + +func (boot *baseBootstrap) getHeaderFromPoolWithNonce( + nonce uint64, +) (data.HeaderHandler, []byte, error) { + if boot.shardCoordinator.SelfId() == core.MetachainShardId { + return process.GetMetaHeaderFromPoolWithNonce(nonce, boot.headers) + } + + return process.GetShardHeaderFromPoolWithNonce(nonce, boot.shardCoordinator.SelfId(), boot.headers) } func (boot *baseBootstrap) isForcedRollBackOneBlock() bool { @@ -1148,6 +1535,7 @@ func (boot *baseBootstrap) init() { boot.poolsHolder.MiniBlocks().RegisterHandler(boot.receivedMiniblock, core.UniqueIdentifier()) boot.headers.RegisterHandler(boot.processReceivedHeader) + boot.proofs.RegisterHandler(boot.processReceivedProof) boot.syncStateListeners = make([]func(bool), 0) boot.requestedHashes = process.RequiredDataPool{} @@ -1159,12 +1547,24 @@ func (boot *baseBootstrap) requestHeaders(fromNonce uint64, toNonce uint64) { defer boot.mutRequestHeaders.Unlock() for currentNonce := fromNonce; currentNonce <= toNonce; currentNonce++ { - haveHeader := boot.blockBootstrapper.haveHeaderInPoolWithNonce(currentNonce) - if haveHeader { + hdr, hash, err := boot.getHeaderFromPoolWithNonce(currentNonce) + hasHeader := err == nil + if hasHeader && boot.hasProof(hash, hdr) { continue } - boot.blockBootstrapper.requestHeaderByNonce(currentNonce) + if hasHeader { + boot.requestHandler.SetEpoch(hdr.GetEpoch()) + } + + needsProof := boot.checkNeedsProofByNonce(currentNonce, hdr, hash) + if !hasHeader { + boot.blockBootstrapper.requestHeaderByNonce(currentNonce) + } + + if needsProof { + boot.blockBootstrapper.requestProofByNonce(currentNonce) + } } } diff --git a/process/sync/errors.go b/process/sync/errors.go index c33db506b65..f77aa404c55 100644 --- a/process/sync/errors.go +++ b/process/sync/errors.go @@ -23,7 +23,7 @@ var ErrLowerRoundInBlock = errors.New("lower round in block") // ErrHigherRoundInBlock signals that the round index in block is higher than the current round of chronology var ErrHigherRoundInBlock = errors.New("higher round in block") 
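// Editorial note, not part of this diff: the needsProof decision implemented by
// checkNeedsProofByHash/checkNeedsProofByNonce in baseSync.go above reduces to the
// combinations below, assuming the proofs (Andromeda) flag is active for the header
// found in the pool:
//
//	header in pool | proof in pool | proof hash matches header | needsProof
//	no             | no            | n/a                       | true
//	no             | yes           | n/a                       | false
//	yes            | no            | n/a                       | true
//	yes            | yes           | no                        | true  (nonce path only)
//	yes            | yes           | yes                       | false
//
// When the flag is not active for the pooled header, needsProof ends up false and the
// bootstrapper waits only for the header, as it did before this change.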
-//ErrCorruptBootstrapFromStorageDb signals that the bootstrap database is corrupt +// ErrCorruptBootstrapFromStorageDb signals that the bootstrap database is corrupt var ErrCorruptBootstrapFromStorageDb = errors.New("corrupt bootstrap storage database") // ErrSignedBlock signals that a block is signed diff --git a/process/sync/export_test.go b/process/sync/export_test.go index 719e7599f9f..f8f172b733e 100644 --- a/process/sync/export_test.go +++ b/process/sync/export_test.go @@ -3,6 +3,7 @@ package sync import ( "time" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-go/common" @@ -10,8 +11,13 @@ import ( ) // RequestHeaderWithNonce - -func (boot *ShardBootstrap) RequestHeaderWithNonce(nonce uint64) { - boot.requestHeaderWithNonce(nonce) +func (boot *baseBootstrap) RequestHeaderWithNonce(nonce uint64) { + if boot.shardCoordinator.SelfId() == core.MetachainShardId { + boot.requestHandler.RequestMetaHeaderByNonce(nonce) + return + } + + boot.requestHandler.RequestShardHeaderByNonce(boot.shardCoordinator.SelfId(), nonce) } // GetMiniBlocks - @@ -24,11 +30,41 @@ func (boot *MetaBootstrap) ReceivedHeaders(header data.HeaderHandler, key []byte boot.processReceivedHeader(header, key) } +// ReceivedProof - +func (boot *MetaBootstrap) ReceivedProof(header data.HeaderProofHandler) { + boot.processReceivedProof(header) +} + +// SetRcvHdrNonce - +func (boot *MetaBootstrap) SetRcvHdrNonce() { + boot.chRcvHdrNonce <- true +} + +// SetRcvHdrHash - +func (boot *MetaBootstrap) SetRcvHdrHash() { + boot.chRcvHdrHash <- true +} + // ReceivedHeaders - func (boot *ShardBootstrap) ReceivedHeaders(header data.HeaderHandler, key []byte) { boot.processReceivedHeader(header, key) } +// ReceivedProof - +func (boot *ShardBootstrap) ReceivedProof(header data.HeaderProofHandler) { + boot.processReceivedProof(header) +} + +// SetRcvHdrNonce - +func (boot *ShardBootstrap) SetRcvHdrNonce() { + boot.chRcvHdrNonce <- true +} + +// SetRcvHdrHash - +func (boot *ShardBootstrap) SetRcvHdrHash() { + boot.chRcvHdrHash <- true +} + // RollBack - func (boot *ShardBootstrap) RollBack(revertUsingForkNonce bool) error { return boot.rollBack(revertUsingForkNonce) @@ -106,11 +142,39 @@ func (bfd *baseForkDetector) IsConsensusStuck() bool { return bfd.isConsensusStuck() } +// Append - +func (bfd *baseForkDetector) Append(hdrInfo *HeaderInfo) bool { + hdr := &headerInfo{ + epoch: hdrInfo.Epoch, + hash: hdrInfo.Hash, + nonce: hdrInfo.Nonce, + round: hdrInfo.Round, + state: hdrInfo.State, + hasProof: hdrInfo.HasProof, + } + return bfd.append(hdr) +} + +// HeaderInfo - +type HeaderInfo struct { + Epoch uint32 + Nonce uint64 + Round uint64 + Hash []byte + State process.BlockHeaderState + HasProof bool +} + // Hash - func (hi *headerInfo) Hash() []byte { return hi.hash } +// HasProof - +func (hi *headerInfo) HasProof() bool { + return hi.hasProof +} + // GetBlockHeaderState - func (hi *headerInfo) GetBlockHeaderState() process.BlockHeaderState { return hi.state diff --git a/process/sync/interface.go b/process/sync/interface.go index 88f644df160..e6965337ea1 100644 --- a/process/sync/interface.go +++ b/process/sync/interface.go @@ -4,6 +4,7 @@ import ( "context" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/storage" ) @@ -14,10 +15,10 @@ type blockBootstrapper interface { 
getBlockBody(headerHandler data.HeaderHandler) (data.BodyHandler, error) getHeaderWithHashRequestingIfMissing(hash []byte) (data.HeaderHandler, error) getHeaderWithNonceRequestingIfMissing(nonce uint64) (data.HeaderHandler, error) - haveHeaderInPoolWithNonce(nonce uint64) bool getBlockBodyRequestingIfMissing(headerHandler data.HeaderHandler) (data.BodyHandler, error) isForkTriggeredByMeta() bool requestHeaderByNonce(nonce uint64) + requestProofByNonce(nonce uint64) } // syncStarter defines the behavior of component that can start sync-ing blocks diff --git a/process/sync/metaForkDetector.go b/process/sync/metaForkDetector.go index 178e4e96042..991bfee7140 100644 --- a/process/sync/metaForkDetector.go +++ b/process/sync/metaForkDetector.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/process" ) @@ -23,6 +25,8 @@ func NewMetaForkDetector( blackListHandler process.TimeCacher, blockTracker process.BlockTracker, genesisTime int64, + enableEpochsHandler common.EnableEpochsHandler, + proofsPool process.ProofsPool, ) (*metaForkDetector, error) { if check.IfNil(roundHandler) { @@ -34,6 +38,12 @@ func NewMetaForkDetector( if check.IfNil(blockTracker) { return nil, process.ErrNilBlockTracker } + if check.IfNil(enableEpochsHandler) { + return nil, process.ErrNilEnableEpochsHandler + } + if check.IfNil(proofsPool) { + return nil, process.ErrNilProofsPool + } genesisHdr, _, err := blockTracker.GetSelfNotarizedHeader(core.MetachainShardId, 0) if err != nil { @@ -41,13 +51,15 @@ func NewMetaForkDetector( } bfd := &baseForkDetector{ - roundHandler: roundHandler, - blackListHandler: blackListHandler, - genesisTime: genesisTime, - blockTracker: blockTracker, - genesisNonce: genesisHdr.GetNonce(), - genesisRound: genesisHdr.GetRound(), - genesisEpoch: genesisHdr.GetEpoch(), + roundHandler: roundHandler, + blackListHandler: blackListHandler, + genesisTime: genesisTime, + blockTracker: blockTracker, + genesisNonce: genesisHdr.GetNonce(), + genesisRound: genesisHdr.GetRound(), + genesisEpoch: genesisHdr.GetEpoch(), + enableEpochsHandler: enableEpochsHandler, + proofsPool: proofsPool, } bfd.headers = make(map[uint64][]*headerInfo) @@ -96,7 +108,11 @@ func (mfd *metaForkDetector) doJobOnBHProcessed( _ [][]byte, ) { mfd.setFinalCheckpoint(mfd.lastCheckpoint()) - mfd.addCheckpoint(&checkpointInfo{nonce: header.GetNonce(), round: header.GetRound(), hash: headerHash}) + newCheckpoint := &checkpointInfo{nonce: header.GetNonce(), round: header.GetRound(), hash: headerHash} + mfd.addCheckpoint(newCheckpoint) + if common.IsProofsFlagEnabledForHeader(mfd.enableEpochsHandler, header) { + mfd.setFinalCheckpoint(newCheckpoint) + } mfd.removePastOrInvalidRecords() } diff --git a/process/sync/metaForkDetector_test.go b/process/sync/metaForkDetector_test.go index 5db5855c6a4..da308ae35cb 100644 --- a/process/sync/metaForkDetector_test.go +++ b/process/sync/metaForkDetector_test.go @@ -5,10 +5,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/process" 
"github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/stretchr/testify/assert" ) @@ -20,6 +25,8 @@ func TestNewMetaForkDetector_NilRoundHandlerShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilRoundHandler, err) @@ -33,6 +40,8 @@ func TestNewMetaForkDetector_NilBlackListShouldErr(t *testing.T) { nil, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilBlackListCacher, err) @@ -46,11 +55,43 @@ func TestNewMetaForkDetector_NilBlockTrackerShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, nil, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilBlockTracker, err) } +func TestNewMetaForkDetector_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + sfd, err := sync.NewMetaForkDetector( + &mock.RoundHandlerMock{}, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + nil, + &dataRetriever.ProofsPoolMock{}, + ) + assert.True(t, check.IfNil(sfd)) + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) +} + +func TestNewMetaForkDetector_NilProofsPoolShouldErr(t *testing.T) { + t.Parallel() + + sfd, err := sync.NewMetaForkDetector( + &mock.RoundHandlerMock{}, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + nil, + ) + assert.True(t, check.IfNil(sfd)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewMetaForkDetector_OkParamsShouldWork(t *testing.T) { t.Parallel() @@ -59,6 +100,8 @@ func TestNewMetaForkDetector_OkParamsShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, ) assert.Nil(t, err) assert.False(t, check.IfNil(sfd)) @@ -73,7 +116,14 @@ func TestMetaForkDetector_AddHeaderNilHeaderShouldErr(t *testing.T) { t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 100} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) err := bfd.AddHeader(nil, make([]byte, 0), process.BHProcessed, nil, nil) assert.Equal(t, sync.ErrNilHeader, err) } @@ -82,7 +132,14 @@ func TestMetaForkDetector_AddHeaderNilHashShouldErr(t *testing.T) { t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 100} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) 
err := bfd.AddHeader(&block.Header{}, nil, process.BHProcessed, nil, nil) assert.Equal(t, sync.ErrNilHash, err) } @@ -93,7 +150,14 @@ func TestMetaForkDetector_AddHeaderNotPresentShouldWork(t *testing.T) { hdr := &block.Header{Nonce: 1, Round: 1, PubKeysBitmap: []byte("X")} hash := make([]byte, 0) roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 1} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) err := bfd.AddHeader(hdr, hash, process.BHProcessed, nil, nil) assert.Nil(t, err) @@ -111,7 +175,14 @@ func TestMetaForkDetector_AddHeaderPresentShouldAppend(t *testing.T) { hdr2 := &block.Header{Nonce: 1, Round: 1, PubKeysBitmap: []byte("X")} hash2 := []byte("hash2") roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 1} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) err := bfd.AddHeader(hdr2, hash2, process.BHProcessed, nil, nil) @@ -129,11 +200,44 @@ func TestMetaForkDetector_AddHeaderWithProcessedBlockShouldSetCheckpoint(t *test hdr1 := &block.Header{Nonce: 69, Round: 72, PubKeysBitmap: []byte("X")} hash1 := []byte("hash1") roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 73} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) assert.Equal(t, hdr1.Nonce, bfd.LastCheckpointNonce()) } +func TestMetaForkDetector_AddHeaderWithProcessedBlockAndFlagShouldSetCheckpoint(t *testing.T) { + t.Parallel() + + hdr1 := &block.Header{Nonce: 23, Round: 25, PubKeysBitmap: []byte("X")} + hash1 := []byte("hash1") + roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 26} + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + }, + &dataRetriever.ProofsPoolMock{ + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true + }, + }, + ) + _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) + assert.Equal(t, hdr1.Nonce, bfd.FinalCheckpointNonce()) +} + func TestMetaForkDetector_AddHeaderPresentShouldNotRewriteState(t *testing.T) { t.Parallel() @@ -141,7 +245,14 @@ func TestMetaForkDetector_AddHeaderPresentShouldNotRewriteState(t *testing.T) { hash := []byte("hash1") hdr2 := &block.Header{Nonce: 1, Round: 1, PubKeysBitmap: []byte("X")} roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 1} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + 
&mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) _ = bfd.AddHeader(hdr1, hash, process.BHReceived, nil, nil) err := bfd.AddHeader(hdr2, hash, process.BHProcessed, nil, nil) @@ -158,7 +269,14 @@ func TestMetaForkDetector_AddHeaderHigherNonceThanRoundShouldErr(t *testing.T) { t.Parallel() roundHandlerMock := &mock.RoundHandlerMock{RoundIndex: 100} - bfd, _ := sync.NewMetaForkDetector(roundHandlerMock, &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0) + bfd, _ := sync.NewMetaForkDetector( + roundHandlerMock, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetriever.ProofsPoolMock{}, + ) err := bfd.AddHeader( &block.Header{Nonce: 1, Round: 0, PubKeysBitmap: []byte("X")}, []byte("hash1"), diff --git a/process/sync/metablock.go b/process/sync/metablock.go index 1b3c69c7386..b821bfe4d3c 100644 --- a/process/sync/metablock.go +++ b/process/sync/metablock.go @@ -31,6 +31,9 @@ func NewMetaBootstrap(arguments ArgMetaBootstrapper) (*MetaBootstrap, error) { if check.IfNil(arguments.PoolsHolder.Headers()) { return nil, process.ErrNilMetaBlocksPool } + if check.IfNil(arguments.PoolsHolder.Proofs()) { + return nil, process.ErrNilProofsPool + } if check.IfNil(arguments.EpochBootstrapper) { return nil, process.ErrNilEpochStartTrigger } @@ -54,6 +57,7 @@ func NewMetaBootstrap(arguments ArgMetaBootstrapper) (*MetaBootstrap, error) { blockProcessor: arguments.BlockProcessor, store: arguments.Store, headers: arguments.PoolsHolder.Headers(), + proofs: arguments.PoolsHolder.Proofs(), roundHandler: arguments.RoundHandler, waitTime: arguments.WaitTime, hasher: arguments.Hasher, @@ -78,6 +82,7 @@ func NewMetaBootstrap(arguments ArgMetaBootstrapper) (*MetaBootstrap, error) { historyRepo: arguments.HistoryRepo, scheduledTxsExecutionHandler: arguments.ScheduledTxsExecutionHandler, processWaitTime: arguments.ProcessWaitTime, + enableEpochsHandler: arguments.EnableEpochsHandler, } if base.isInImportMode { @@ -93,7 +98,6 @@ func NewMetaBootstrap(arguments ArgMetaBootstrapper) (*MetaBootstrap, error) { base.blockBootstrapper = &boot base.syncStarter = &boot - base.getHeaderFromPool = boot.getMetaHeaderFromPool base.requestMiniBlocks = boot.requestMiniBlocksFromHeaderWithNonceIfMissing // placed in struct fields for performance reasons @@ -221,72 +225,6 @@ func (boot *MetaBootstrap) Close() error { return boot.baseBootstrap.Close() } -// requestHeaderWithNonce method requests a block header from network when it is not found in the pool -func (boot *MetaBootstrap) requestHeaderWithNonce(nonce uint64) { - boot.setRequestedHeaderNonce(&nonce) - log.Debug("requesting meta header from network", - "nonce", nonce, - "probable highest nonce", boot.forkDetector.ProbableHighestNonce(), - ) - boot.requestHandler.RequestMetaHeaderByNonce(nonce) -} - -// requestHeaderWithHash method requests a block header from network when it is not found in the pool -func (boot *MetaBootstrap) requestHeaderWithHash(hash []byte) { - boot.setRequestedHeaderHash(hash) - log.Debug("requesting meta header from network", - "hash", hash, - "probable highest nonce", boot.forkDetector.ProbableHighestNonce(), - ) - boot.requestHandler.RequestMetaHeader(hash) -} - -// getHeaderWithNonceRequestingIfMissing method gets the header with a given nonce from pool. 
If it is not found there, it will -// be requested from network -func (boot *MetaBootstrap) getHeaderWithNonceRequestingIfMissing(nonce uint64) (data.HeaderHandler, error) { - hdr, _, err := process.GetMetaHeaderFromPoolWithNonce( - nonce, - boot.headers) - if err != nil { - _ = core.EmptyChannel(boot.chRcvHdrNonce) - boot.requestHeaderWithNonce(nonce) - err = boot.waitForHeaderNonce() - if err != nil { - return nil, err - } - - hdr, _, err = process.GetMetaHeaderFromPoolWithNonce( - nonce, - boot.headers) - if err != nil { - return nil, err - } - } - - return hdr, nil -} - -// getHeaderWithHashRequestingIfMissing method gets the header with a given hash from pool. If it is not found there, -// it will be requested from network -func (boot *MetaBootstrap) getHeaderWithHashRequestingIfMissing(hash []byte) (data.HeaderHandler, error) { - hdr, err := process.GetMetaHeader(hash, boot.headers, boot.marshalizer, boot.store) - if err != nil { - _ = core.EmptyChannel(boot.chRcvHdrHash) - boot.requestHeaderWithHash(hash) - err = boot.waitForHeaderHash() - if err != nil { - return nil, err - } - - hdr, err = process.GetMetaHeaderFromPool(hash, boot.headers) - if err != nil { - return nil, err - } - } - - return hdr, nil -} - func (boot *MetaBootstrap) getPrevHeader( header data.HeaderHandler, headerStore storage.Storer, @@ -321,18 +259,6 @@ func (boot *MetaBootstrap) getCurrHeader() (data.HeaderHandler, error) { return header, nil } -func (boot *MetaBootstrap) haveHeaderInPoolWithNonce(nonce uint64) bool { - _, _, err := process.GetMetaHeaderFromPoolWithNonce( - nonce, - boot.headers) - - return err == nil -} - -func (boot *MetaBootstrap) getMetaHeaderFromPool(headerHash []byte) (data.HeaderHandler, error) { - return process.GetMetaHeaderFromPool(headerHash, boot.headers) -} - func (boot *MetaBootstrap) getBlockBodyRequestingIfMissing(headerHandler data.HeaderHandler) (data.BodyHandler, error) { header, ok := headerHandler.(*block.MetaBlock) if !ok { @@ -392,6 +318,10 @@ func (boot *MetaBootstrap) requestHeaderByNonce(nonce uint64) { boot.requestHandler.RequestMetaHeaderByNonce(nonce) } +func (boot *MetaBootstrap) requestProofByNonce(nonce uint64) { + boot.requestHandler.RequestEquivalentProofByNonce(core.MetachainShardId, nonce) +} + // IsInterfaceNil returns true if there is no value under the interface func (boot *MetaBootstrap) IsInterfaceNil() bool { return boot == nil diff --git a/process/sync/metablock_test.go b/process/sync/metablock_test.go index 6d183fbf821..8649b23cca1 100644 --- a/process/sync/metablock_test.go +++ b/process/sync/metablock_test.go @@ -15,23 +15,26 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus/round" "github.com/multiversx/mx-chain-go/dataRetriever" - "github.com/multiversx/mx-chain-go/dataRetriever/blockchain" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + 
"github.com/multiversx/mx-chain-go/testscommon/cache" + dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/outport" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMetaBlockProcessor(blk data.ChainHandler) *testscommon.BlockProcessorStub { @@ -53,17 +56,10 @@ func createMetaBlockProcessor(blk data.ChainHandler) *testscommon.BlockProcessor return blockProcessorMock } -func createMetaStore() dataRetriever.StorageService { - store := dataRetriever.NewChainStorer() - store.AddStorer(dataRetriever.MetaBlockUnit, generateTestUnit()) - store.AddStorer(dataRetriever.ShardHdrNonceHashDataUnit, generateTestUnit()) - store.AddStorer(dataRetriever.MetaHdrNonceHashDataUnit, generateTestUnit()) - store.AddStorer(dataRetriever.UserAccountsUnit, generateTestUnit()) - store.AddStorer(dataRetriever.PeerAccountsUnit, generateTestUnit()) - return store -} - func CreateMetaBootstrapMockArguments() sync.ArgMetaBootstrapper { + shardCoordinator := mock.NewOneShardCoordinatorMock() + _ = shardCoordinator.SetSelfId(core.MetachainShardId) + argsBaseBootstrapper := sync.ArgBaseBootstrapper{ PoolsHolder: createMockPools(), Store: createStore(), @@ -75,7 +71,7 @@ func CreateMetaBootstrapMockArguments() sync.ArgMetaBootstrapper { Marshalizer: &mock.MarshalizerMock{}, ForkDetector: &mock.ForkDetectorMock{}, RequestHandler: &testscommon.RequestHandlerStub{}, - ShardCoordinator: mock.NewOneShardCoordinatorMock(), + ShardCoordinator: shardCoordinator, Accounts: &stateMock.AccountsStub{}, BlackListHandler: &testscommon.TimeCacheStub{}, NetworkWatcher: initNetworkWatcher(), @@ -92,6 +88,7 @@ func CreateMetaBootstrapMockArguments() sync.ArgMetaBootstrapper { ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, ProcessWaitTime: testProcessWaitTime, RepopulateTokensSupplies: false, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } argsMetaBootstrapper := sync.ArgMetaBootstrapper{ @@ -170,6 +167,22 @@ func TestNewMetaBootstrap_PoolsHolderRetNilOnHeadersShouldErr(t *testing.T) { assert.Equal(t, process.ErrNilMetaBlocksPool, err) } +func TestNewMetaBootstrap_NilProofsPool(t *testing.T) { + t.Parallel() + + args := CreateMetaBootstrapMockArguments() + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return nil + } + args.PoolsHolder = pools + + bs, err := sync.NewMetaBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewMetaBootstrap_NilStoreShouldErr(t *testing.T) { t.Parallel() @@ -386,6 +399,34 @@ func TestNewMetaBootstrap_InvalidProcessTimeShouldErr(t *testing.T) { assert.True(t, errors.Is(err, process.ErrInvalidProcessWaitTime)) } +func TestNewMetaBootstrap_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + args := CreateMetaBootstrapMockArguments() + 
args.EnableEpochsHandler = nil + + bs, err := sync.NewMetaBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.True(t, errors.Is(err, process.ErrNilEnableEpochsHandler)) +} + +func TestNewMetaBootstrap_PoolsHolderRetNilOnProofsShouldErr(t *testing.T) { + t.Parallel() + + args := CreateMetaBootstrapMockArguments() + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return nil + } + args.PoolsHolder = pools + + bs, err := sync.NewMetaBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewMetaBootstrap_MissingStorer(t *testing.T) { t.Parallel() @@ -652,7 +693,7 @@ func TestMetaBootstrap_ShouldReturnNilErr(t *testing.T) { return sds } pools.MiniBlocksCalled = func() storage.Cacher { - sds := &testscommon.CacherStub{ + sds := &cache.CacherStub{ HasOrAddCalled: func(key []byte, value interface{}, sizeInBytes int) (has, added bool) { return false, true }, @@ -919,6 +960,8 @@ func TestMetaBootstrap_GetNodeStateShouldReturnNotSynchronizedWhenForkIsDetected &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) bs, _ := sync.NewMetaBootstrap(args) @@ -984,6 +1027,8 @@ func TestMetaBootstrap_GetNodeStateShouldReturnSynchronizedWhenForkIsDetectedAnd &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) bs, _ := sync.NewMetaBootstrap(args) @@ -1075,118 +1120,95 @@ func TestMetaBootstrap_GetHeaderFromPoolShouldReturnHeader(t *testing.T) { // ------- testing received headers -func TestMetaBootstrap_ReceivedHeadersFoundInPoolShouldAddToForkDetector(t *testing.T) { +func TestMetaBootstrap_ReceivedHeaders(t *testing.T) { t.Parallel() - args := CreateMetaBootstrapMockArguments() + t.Run("should add to fork detector if header hash matches", func(t *testing.T) { + t.Parallel() - addedHash := []byte("hash") - addedHdr := &block.MetaBlock{} + args := CreateMetaBootstrapMockArguments() - pools := createMockPools() - pools.HeadersCalled = func() dataRetriever.HeadersPool { - sds := &mock.HeadersCacherStub{} - sds.RegisterHandlerCalled = func(func(header data.HeaderHandler, key []byte)) { - } - sds.GetHeaderByHashCalled = func(key []byte) (handler data.HeaderHandler, e error) { - if bytes.Equal(key, addedHash) { - return addedHdr, nil - } + addedHash := []byte("hash") + addedHdr := &block.MetaBlock{} - return nil, errors.New("err") - } + wasAdded := false - return sds - } - args.PoolsHolder = pools + forkDetector := &mock.ForkDetectorMock{} + forkDetector.AddHeaderCalled = func(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error { + if state == process.BHProcessed { + return errors.New("processed") + } - wasAdded := false + if !bytes.Equal(hash, addedHash) { + return errors.New("hash mismatch") + } - forkDetector := &mock.ForkDetectorMock{} - forkDetector.AddHeaderCalled = func(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error { - if state == process.BHProcessed { - return errors.New("processed") - } + if !reflect.DeepEqual(header, addedHdr) { + return errors.New("header mismatch") + } - if !bytes.Equal(hash, addedHash) { - return errors.New("hash mismatch") + wasAdded = true + return nil } - - if 
!reflect.DeepEqual(header, addedHdr) { - return errors.New("header mismatch") + forkDetector.ProbableHighestNonceCalled = func() uint64 { + return 0 } + args.ForkDetector = forkDetector - wasAdded = true - return nil - } - forkDetector.ProbableHighestNonceCalled = func() uint64 { - return 0 - } - args.ForkDetector = forkDetector + shardCoordinator := mock.NewMultipleShardsCoordinatorMock() + shardCoordinator.CurrentShard = core.MetachainShardId + args.ShardCoordinator = shardCoordinator + args.RoundHandler = initRoundHandler() - shardCoordinator := mock.NewMultipleShardsCoordinatorMock() - shardCoordinator.CurrentShard = core.MetachainShardId - args.ShardCoordinator = shardCoordinator - args.RoundHandler = initRoundHandler() - - bs, err := sync.NewMetaBootstrap(args) - require.Nil(t, err) - bs.ReceivedHeaders(addedHdr, addedHash) - time.Sleep(500 * time.Millisecond) - - assert.True(t, wasAdded) -} - -func TestMetaBootstrap_ReceivedHeadersNotFoundInPoolShouldNotAddToForkDetector(t *testing.T) { - t.Parallel() + bs, err := sync.NewMetaBootstrap(args) + require.Nil(t, err) + bs.ReceivedHeaders(addedHdr, addedHash) + time.Sleep(500 * time.Millisecond) - args := CreateMetaBootstrapMockArguments() + assert.True(t, wasAdded) + }) - addedHash := []byte("hash") - addedHdr := &block.MetaBlock{Nonce: 1} + t.Run("should not add to fork detector if header hash does not match", func(t *testing.T) { + t.Parallel() - wasAdded := false + args := CreateMetaBootstrapMockArguments() - forkDetector := &mock.ForkDetectorMock{} - forkDetector.AddHeaderCalled = func(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error { - if state == process.BHProcessed { - return errors.New("processed") - } + addedHash := []byte("hash") + addedHdr := &block.MetaBlock{} - if !bytes.Equal(hash, addedHash) { - return errors.New("hash mismatch") - } + wasAdded := false - if !reflect.DeepEqual(header, addedHdr) { - return errors.New("header mismatch") - } + forkDetector := &mock.ForkDetectorMock{} + forkDetector.AddHeaderCalled = func(header data.HeaderHandler, hash []byte, state process.BlockHeaderState, selfNotarizedHeaders []data.HeaderHandler, selfNotarizedHeadersHashes [][]byte) error { + if state == process.BHProcessed { + return errors.New("processed") + } - wasAdded = true - return nil - } - args.ForkDetector = forkDetector + if !bytes.Equal(hash, addedHash) { + return errors.New("hash mismatch") + } - headerStorage := &storageStubs.StorerStub{} - headerStorage.GetCalled = func(key []byte) (i []byte, e error) { - if bytes.Equal(key, addedHash) { - buff, _ := args.Marshalizer.Marshal(addedHdr) + if !reflect.DeepEqual(header, addedHdr) { + return errors.New("header mismatch") + } - return buff, nil + wasAdded = true + return nil } + args.ForkDetector = forkDetector - return nil, nil - } - args.Store = createMetaStore() - args.Store.AddStorer(dataRetriever.MetaBlockUnit, headerStorage) - args.ChainHandler, _ = blockchain.NewBlockChain(&statusHandlerMock.AppStatusHandlerStub{}) - args.RoundHandler = initRoundHandler() + shardCoordinator := mock.NewMultipleShardsCoordinatorMock() + shardCoordinator.CurrentShard = core.MetachainShardId + args.ShardCoordinator = shardCoordinator + args.RoundHandler = initRoundHandler() - bs, err := sync.NewMetaBootstrap(args) - require.Nil(t, err) - bs.ReceivedHeaders(addedHdr, addedHash) - time.Sleep(500 * time.Millisecond) + bs, err := sync.NewMetaBootstrap(args) + require.Nil(t, err) + 
bs.ReceivedHeaders(addedHdr, []byte("otherHash")) + time.Sleep(500 * time.Millisecond) - assert.False(t, wasAdded) + assert.False(t, wasAdded) + }) } // ------- RollBack @@ -1615,6 +1637,7 @@ func TestMetaBootstrap_SyncBlockErrGetNodeDBShouldSyncAccounts(t *testing.T) { t.Parallel() args := CreateMetaBootstrapMockArguments() + hdr := block.MetaBlock{Nonce: 1, PubKeysBitmap: []byte("X")} blkc := &testscommon.ChainHandlerStub{ GetCurrentBlockHeaderCalled: func() data.HeaderHandler { @@ -1653,6 +1676,7 @@ func TestMetaBootstrap_SyncBlockErrGetNodeDBShouldSyncAccounts(t *testing.T) { return sds } + args.PoolsHolder = pools forkDetector := &mock.ForkDetectorMock{} @@ -1727,6 +1751,273 @@ func TestMetaBootstrap_SyncBlockErrGetNodeDBShouldSyncAccounts(t *testing.T) { assert.True(t, accountsSyncCalled) } +func TestMetaBootstrap_SyncBlock_WithEquivalentProofs(t *testing.T) { + t.Parallel() + + t.Run("time is out when existing header and missing proof", func(t *testing.T) { + t.Parallel() + + args := CreateMetaBootstrapMockArguments() + + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AndromedaFlag + }, + } + + hdr := block.MetaBlock{Nonce: 1} + blkc := &testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{} + }, + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &hdr + }, + } + args.ChainHandler = blkc + + forkDetector := &mock.ForkDetectorMock{} + forkDetector.CheckForkCalled = func() *process.ForkInfo { + return process.NewForkInfo() + } + forkDetector.ProbableHighestNonceCalled = func() uint64 { + return 100 + } + forkDetector.GetNotarizedHeaderHashCalled = func(nonce uint64) []byte { + return nil + } + args.ForkDetector = forkDetector + args.RoundHandler, _ = round.NewRound(time.Now(), + time.Now().Add(2*100*time.Millisecond), + 100*time.Millisecond, + &mock.SyncTimerMock{}, + 0, + ) + args.BlockProcessor = createMetaBlockProcessor(args.ChainHandler) + + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + GetProofByNonceCalled: func(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + return nil, errors.New("missing proof") + }, + } + } + + args.PoolsHolder = pools + + bs, _ := sync.NewMetaBootstrap(args) + r := bs.SyncBlock(context.Background()) + + assert.Equal(t, process.ErrTimeIsOut, r) + }) + + t.Run("should receive header and proof if missing, requesting by nonce", func(t *testing.T) { + t.Parallel() + + args := CreateMetaBootstrapMockArguments() + + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + + hdr := block.MetaBlock{Nonce: 1} + blkc := &testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{} + }, + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &hdr + }, + } + args.ChainHandler = blkc + + forkDetector := &mock.ForkDetectorMock{} + forkDetector.CheckForkCalled = func() *process.ForkInfo { + return process.NewForkInfo() + } + forkDetector.ProbableHighestNonceCalled = func() uint64 { + return 100 + } + forkDetector.GetNotarizedHeaderHashCalled = func(nonce uint64) []byte { + return nil + } + args.ForkDetector = forkDetector + args.RoundHandler, _ = round.NewRound(time.Now(), + 
time.Now().Add(2*100*time.Millisecond), + 100*time.Millisecond, + &mock.SyncTimerMock{}, + 0, + ) + args.BlockProcessor = createMetaBlockProcessor(args.ChainHandler) + + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + GetProofByNonceCalled: func(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + return nil, errors.New("missing proof") + }, + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true // second check after wait is done by hash + }, + } + } + + numHeaderCalls := 0 + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByNonceAndShardIdCalled = func(hdrNonce uint64, shardId uint32) (handlers []data.HeaderHandler, i [][]byte, e error) { + if numHeaderCalls == 0 { + numHeaderCalls++ + return nil, nil, errors.New("err") + } + + return []data.HeaderHandler{ + &block.MetaBlock{ + Nonce: 1, + Round: 1, + RootHash: []byte("bbb")}, + }, [][]byte{[]byte("aaa")}, nil + } + + return sds + } + args.PoolsHolder = pools + + receive := make(chan bool, 2) + + args.RequestHandler = &testscommon.RequestHandlerStub{ + RequestMetaHeaderByNonceCalled: func(nonce uint64) { + receive <- true + }, + RequestEquivalentProofByNonceCalled: func(headerShard uint32, headerNonce uint64) { + receive <- true + }, + } + + bs, _ := sync.NewMetaBootstrap(args) + + go func() { + // wait for both header and proof requests + <-receive + <-receive + + bs.SetRcvHdrNonce() + }() + + err := bs.SyncBlock(context.Background()) + + assert.Nil(t, err) + }) + + t.Run("should receive header and proof if missing, requesting by hash", func(t *testing.T) { + t.Parallel() + + args := CreateMetaBootstrapMockArguments() + + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AndromedaFlag + }, + } + + hdr := block.MetaBlock{Nonce: 1} + blkc := &testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{} + }, + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &hdr + }, + } + args.ChainHandler = blkc + + forkDetector := &mock.ForkDetectorMock{} + forkDetector.CheckForkCalled = func() *process.ForkInfo { + return process.NewForkInfo() + } + forkDetector.ProbableHighestNonceCalled = func() uint64 { + return 100 + } + + hash := []byte("hash1") + forkDetector.GetNotarizedHeaderHashCalled = func(nonce uint64) []byte { + return hash + } + args.ForkDetector = forkDetector + args.RoundHandler, _ = round.NewRound(time.Now(), + time.Now().Add(2*100*time.Millisecond), + 100*time.Millisecond, + &mock.SyncTimerMock{}, + 0, + ) + args.BlockProcessor = createMetaBlockProcessor(args.ChainHandler) + + pools := createMockPools() + + numProofCalls := 0 + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + return nil, errors.New("missing proof") + }, + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + if numProofCalls == 0 { + numProofCalls++ + return false + } + + return true // second check after wait is done by hash + }, + } + } + + numHeaderCalls := 0 + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if numHeaderCalls == 0 { 
+ numHeaderCalls++ + return nil, errors.New("err") + } + + return &block.MetaBlock{}, nil + } + + return sds + } + args.PoolsHolder = pools + + receive := make(chan bool, 2) + + args.RequestHandler = &testscommon.RequestHandlerStub{ + RequestMetaHeaderCalled: func(hash []byte) { + receive <- true + }, + RequestEquivalentProofByHashCalled: func(headerShard uint32, headerHash []byte) { + receive <- true + }, + } + + bs, _ := sync.NewMetaBootstrap(args) + + go func() { + // wait for both header and proof requests + <-receive + <-receive + + bs.SetRcvHdrHash() + }() + + err := bs.SyncBlock(context.Background()) + + assert.Nil(t, err) + }) +} + func TestMetaBootstrap_SyncAccountsDBs(t *testing.T) { t.Parallel() diff --git a/process/sync/shardForkDetector.go b/process/sync/shardForkDetector.go index 52715f36163..a45ed8bd77b 100644 --- a/process/sync/shardForkDetector.go +++ b/process/sync/shardForkDetector.go @@ -7,6 +7,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/process" ) @@ -24,6 +26,8 @@ func NewShardForkDetector( blackListHandler process.TimeCacher, blockTracker process.BlockTracker, genesisTime int64, + enableEpochsHandler common.EnableEpochsHandler, + proofsPool process.ProofsPool, ) (*shardForkDetector, error) { if check.IfNil(roundHandler) { @@ -35,6 +39,12 @@ func NewShardForkDetector( if check.IfNil(blockTracker) { return nil, process.ErrNilBlockTracker } + if check.IfNil(enableEpochsHandler) { + return nil, process.ErrNilEnableEpochsHandler + } + if check.IfNil(proofsPool) { + return nil, process.ErrNilProofsPool + } genesisHdr, _, err := blockTracker.GetSelfNotarizedHeader(core.MetachainShardId, 0) if err != nil { @@ -42,13 +52,15 @@ func NewShardForkDetector( } bfd := &baseForkDetector{ - roundHandler: roundHandler, - blackListHandler: blackListHandler, - genesisTime: genesisTime, - blockTracker: blockTracker, - genesisNonce: genesisHdr.GetNonce(), - genesisRound: genesisHdr.GetRound(), - genesisEpoch: genesisHdr.GetEpoch(), + roundHandler: roundHandler, + blackListHandler: blackListHandler, + genesisTime: genesisTime, + blockTracker: blockTracker, + genesisNonce: genesisHdr.GetNonce(), + genesisRound: genesisHdr.GetRound(), + genesisEpoch: genesisHdr.GetEpoch(), + enableEpochsHandler: enableEpochsHandler, + proofsPool: proofsPool, } bfd.headers = make(map[uint64][]*headerInfo) @@ -100,7 +112,13 @@ func (sfd *shardForkDetector) doJobOnBHProcessed( ) { _ = sfd.appendSelfNotarizedHeaders(selfNotarizedHeaders, selfNotarizedHeadersHashes, core.MetachainShardId) sfd.computeFinalCheckpoint() - sfd.addCheckpoint(&checkpointInfo{nonce: header.GetNonce(), round: header.GetRound(), hash: headerHash}) + newCheckpoint := &checkpointInfo{nonce: header.GetNonce(), round: header.GetRound(), hash: headerHash} + sfd.addCheckpoint(newCheckpoint) + // first shard block with proof does not have increased consensus + // so instant finality will only be set after the first block with increased consensus + if common.IsFlagEnabledAfterEpochsStartBlock(header, sfd.enableEpochsHandler, common.AndromedaFlag) { + sfd.setFinalCheckpoint(newCheckpoint) + } sfd.removePastOrInvalidRecords() } @@ -136,11 +154,13 @@ func (sfd *shardForkDetector) appendSelfNotarizedHeaders( 
continue } + hasProof := sfd.proofsPool.HasProof(selfNotarizedHeaders[i].GetShardID(), selfNotarizedHeadersHashes[i]) appended := sfd.append(&headerInfo{ - nonce: selfNotarizedHeaders[i].GetNonce(), - round: selfNotarizedHeaders[i].GetRound(), - hash: selfNotarizedHeadersHashes[i], - state: process.BHNotarized, + nonce: selfNotarizedHeaders[i].GetNonce(), + round: selfNotarizedHeaders[i].GetRound(), + hash: selfNotarizedHeadersHashes[i], + state: process.BHNotarized, + hasProof: hasProof, }) if appended { log.Debug("added self notarized header in fork detector", diff --git a/process/sync/shardForkDetector_test.go b/process/sync/shardForkDetector_test.go index 98412430e71..d3b37a0dfd1 100644 --- a/process/sync/shardForkDetector_test.go +++ b/process/sync/shardForkDetector_test.go @@ -5,10 +5,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/testscommon" + dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/stretchr/testify/assert" ) @@ -20,6 +24,8 @@ func TestNewShardForkDetector_NilRoundHandlerShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilRoundHandler, err) @@ -33,6 +39,8 @@ func TestNewShardForkDetector_NilBlackListShouldErr(t *testing.T) { nil, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilBlackListCacher, err) @@ -46,11 +54,43 @@ func TestNewShardForkDetector_NilBlockTrackerShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, nil, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) assert.True(t, check.IfNil(sfd)) assert.Equal(t, process.ErrNilBlockTracker, err) } +func TestNewShardForkDetector_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + sfd, err := sync.NewShardForkDetector( + &mock.RoundHandlerMock{}, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + nil, + &dataRetrieverMock.ProofsPoolMock{}, + ) + assert.True(t, check.IfNil(sfd)) + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) +} + +func TestNewShardForkDetector_NilProofsPoolShouldErr(t *testing.T) { + t.Parallel() + + sfd, err := sync.NewShardForkDetector( + &mock.RoundHandlerMock{}, + &testscommon.TimeCacheStub{}, + &mock.BlockTrackerMock{}, + 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + nil, + ) + assert.True(t, check.IfNil(sfd)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewShardForkDetector_OkParamsShouldWork(t *testing.T) { t.Parallel() @@ -59,6 +99,8 @@ func TestNewShardForkDetector_OkParamsShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) assert.Nil(t, err) assert.False(t, check.IfNil(sfd)) @@ -78,6 +120,8 @@ func 
TestShardForkDetector_AddHeaderNilHeaderShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) err := bfd.AddHeader(nil, make([]byte, 0), process.BHProcessed, nil, nil) assert.Equal(t, sync.ErrNilHeader, err) @@ -92,6 +136,8 @@ func TestShardForkDetector_AddHeaderNilHashShouldErr(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) err := bfd.AddHeader(&block.Header{}, nil, process.BHProcessed, nil, nil) assert.Equal(t, sync.ErrNilHash, err) @@ -108,6 +154,8 @@ func TestShardForkDetector_AddHeaderNotPresentShouldWork(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) err := bfd.AddHeader(hdr, hash, process.BHProcessed, nil, nil) assert.Nil(t, err) @@ -130,6 +178,8 @@ func TestShardForkDetector_AddHeaderPresentShouldAppend(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) err := bfd.AddHeader(hdr2, hash2, process.BHProcessed, nil, nil) @@ -152,6 +202,8 @@ func TestShardForkDetector_AddHeaderWithProcessedBlockShouldSetCheckpoint(t *tes &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) _ = bfd.AddHeader(hdr1, hash1, process.BHProcessed, nil, nil) assert.Equal(t, hdr1.Nonce, bfd.LastCheckpointNonce()) @@ -169,6 +221,8 @@ func TestShardForkDetector_AddHeaderPresentShouldNotRewriteState(t *testing.T) { &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) _ = bfd.AddHeader(hdr1, hash, process.BHReceived, nil, nil) err := bfd.AddHeader(hdr2, hash, process.BHProcessed, nil, nil) @@ -190,6 +244,8 @@ func TestShardForkDetector_AddHeaderHigherNonceThanRoundShouldErr(t *testing.T) &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) err := bfd.AddHeader( &block.Header{Nonce: 1, Round: 0, PubKeysBitmap: []byte("X")}, []byte("hash1"), process.BHProcessed, nil, nil) diff --git a/process/sync/shardblock.go b/process/sync/shardblock.go index 8cca3954ef0..1986a394ff6 100644 --- a/process/sync/shardblock.go +++ b/process/sync/shardblock.go @@ -27,6 +27,9 @@ func NewShardBootstrap(arguments ArgShardBootstrapper) (*ShardBootstrap, error) if check.IfNil(arguments.PoolsHolder.Headers()) { return nil, process.ErrNilHeadersDataPool } + if check.IfNil(arguments.PoolsHolder.Proofs()) { + return nil, process.ErrNilProofsPool + } if check.IfNil(arguments.PoolsHolder.MiniBlocks()) { return nil, process.ErrNilTxBlockBody } @@ -41,6 +44,7 @@ func NewShardBootstrap(arguments ArgShardBootstrapper) (*ShardBootstrap, error) blockProcessor: arguments.BlockProcessor, store: arguments.Store, headers: arguments.PoolsHolder.Headers(), + proofs: arguments.PoolsHolder.Proofs(), roundHandler: arguments.RoundHandler, waitTime: arguments.WaitTime, hasher: arguments.Hasher, @@ -66,6 +70,7 @@ func NewShardBootstrap(arguments ArgShardBootstrapper) (*ShardBootstrap, error) scheduledTxsExecutionHandler: 
arguments.ScheduledTxsExecutionHandler, processWaitTime: arguments.ProcessWaitTime, repopulateTokensSupplies: arguments.RepopulateTokensSupplies, + enableEpochsHandler: arguments.EnableEpochsHandler, } if base.isInImportMode { @@ -78,7 +83,6 @@ func NewShardBootstrap(arguments ArgShardBootstrapper) (*ShardBootstrap, error) base.blockBootstrapper = &boot base.syncStarter = &boot - base.getHeaderFromPool = boot.getShardHeaderFromPool base.requestMiniBlocks = boot.requestMiniBlocksFromHeaderWithNonceIfMissing // placed in struct fields for performance reasons @@ -87,7 +91,7 @@ func NewShardBootstrap(arguments ArgShardBootstrapper) (*ShardBootstrap, error) return nil, err } - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(boot.shardCoordinator.SelfId()) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(boot.shardCoordinator.SelfId()) base.headerNonceHashStore, err = boot.store.GetStorer(hdrNonceHashDataUnit) if err != nil { return nil, err @@ -174,74 +178,6 @@ func (boot *ShardBootstrap) Close() error { return boot.baseBootstrap.Close() } -// requestHeaderWithNonce method requests a block header from network when it is not found in the pool -func (boot *ShardBootstrap) requestHeaderWithNonce(nonce uint64) { - boot.setRequestedHeaderNonce(&nonce) - log.Debug("requesting shard header from network", - "nonce", nonce, - "probable highest nonce", boot.forkDetector.ProbableHighestNonce(), - ) - boot.requestHandler.RequestShardHeaderByNonce(boot.shardCoordinator.SelfId(), nonce) -} - -// requestHeaderWithHash method requests a block header from network when it is not found in the pool -func (boot *ShardBootstrap) requestHeaderWithHash(hash []byte) { - boot.setRequestedHeaderHash(hash) - log.Debug("requesting shard header from network", - "hash", hash, - "probable highest nonce", boot.forkDetector.ProbableHighestNonce(), - ) - boot.requestHandler.RequestShardHeader(boot.shardCoordinator.SelfId(), hash) -} - -// getHeaderWithNonceRequestingIfMissing method gets the header with a given nonce from pool. If it is not found there, it will -// be requested from network -func (boot *ShardBootstrap) getHeaderWithNonceRequestingIfMissing(nonce uint64) (data.HeaderHandler, error) { - hdr, _, err := process.GetShardHeaderFromPoolWithNonce( - nonce, - boot.shardCoordinator.SelfId(), - boot.headers) - if err != nil { - _ = core.EmptyChannel(boot.chRcvHdrNonce) - boot.requestHeaderWithNonce(nonce) - err = boot.waitForHeaderNonce() - if err != nil { - return nil, err - } - - hdr, _, err = process.GetShardHeaderFromPoolWithNonce( - nonce, - boot.shardCoordinator.SelfId(), - boot.headers) - if err != nil { - return nil, err - } - } - - return hdr, nil -} - -// getHeaderWithHashRequestingIfMissing method gets the header with a given hash from pool. 
If it is not found there, -// it will be requested from network -func (boot *ShardBootstrap) getHeaderWithHashRequestingIfMissing(hash []byte) (data.HeaderHandler, error) { - hdr, err := process.GetShardHeader(hash, boot.headers, boot.marshalizer, boot.store) - if err != nil { - _ = core.EmptyChannel(boot.chRcvHdrHash) - boot.requestHeaderWithHash(hash) - err = boot.waitForHeaderHash() - if err != nil { - return nil, err - } - - hdr, err = process.GetShardHeaderFromPool(hash, boot.headers) - if err != nil { - return nil, err - } - } - - return hdr, nil -} - func (boot *ShardBootstrap) getPrevHeader( header data.HeaderHandler, headerStore storage.Storer, @@ -275,19 +211,6 @@ func (boot *ShardBootstrap) getCurrHeader() (data.HeaderHandler, error) { return header, nil } -func (boot *ShardBootstrap) haveHeaderInPoolWithNonce(nonce uint64) bool { - _, _, err := process.GetShardHeaderFromPoolWithNonce( - nonce, - boot.shardCoordinator.SelfId(), - boot.headers) - - return err == nil -} - -func (boot *ShardBootstrap) getShardHeaderFromPool(headerHash []byte) (data.HeaderHandler, error) { - return process.GetShardHeaderFromPool(headerHash, boot.headers) -} - func (boot *ShardBootstrap) requestMiniBlocksFromHeaderWithNonceIfMissing(headerHandler data.HeaderHandler) { nextBlockNonce := boot.getNonceForNextBlock() maxNonce := core.MinUint64(nextBlockNonce+process.MaxHeadersToRequestInAdvance-1, boot.forkDetector.ProbableHighestNonce()) @@ -350,6 +273,10 @@ func (boot *ShardBootstrap) requestHeaderByNonce(nonce uint64) { boot.requestHandler.RequestShardHeaderByNonce(boot.shardCoordinator.SelfId(), nonce) } +func (boot *ShardBootstrap) requestProofByNonce(nonce uint64) { + boot.requestHandler.RequestEquivalentProofByNonce(boot.shardCoordinator.SelfId(), nonce) +} + // IsInterfaceNil returns true if there is no value under the interface func (boot *ShardBootstrap) IsInterfaceNil() bool { return boot == nil diff --git a/process/sync/shardblock_test.go b/process/sync/shardblock_test.go index 070b926df0f..cba152bca5c 100644 --- a/process/sync/shardblock_test.go +++ b/process/sync/shardblock_test.go @@ -16,6 +16,9 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/consensus/round" @@ -28,15 +31,15 @@ import ( "github.com/multiversx/mx-chain-go/storage/database" "github.com/multiversx/mx-chain-go/storage/storageunit" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" "github.com/multiversx/mx-chain-go/testscommon/dblookupext" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/outport" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" statusHandlerMock 
"github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) // waitTime defines the time in milliseconds until node waits the requested info from the network @@ -55,7 +58,7 @@ func createMockPools() *dataRetrieverMock.PoolsHolderStub { return &mock.HeadersCacherStub{} } pools.MiniBlocksCalled = func() storage.Cacher { - cs := &testscommon.CacherStub{ + cs := &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -63,6 +66,9 @@ func createMockPools() *dataRetrieverMock.PoolsHolderStub { } return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } return pools } @@ -219,6 +225,7 @@ func CreateShardBootstrapMockArguments() sync.ArgShardBootstrapper { ScheduledTxsExecutionHandler: &testscommon.ScheduledTxsExecutionStub{}, ProcessWaitTime: testProcessWaitTime, RepopulateTokensSupplies: false, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } argsShardBootstrapper := sync.ArgShardBootstrapper{ @@ -270,6 +277,22 @@ func TestNewShardBootstrap_PoolsHolderRetNilOnHeadersShouldErr(t *testing.T) { assert.Equal(t, process.ErrNilHeadersDataPool, err) } +func TestNewShardBootstrap_NilProofsPool(t *testing.T) { + t.Parallel() + + args := CreateShardBootstrapMockArguments() + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return nil + } + args.PoolsHolder = pools + + bs, err := sync.NewShardBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewShardBootstrap_PoolsHolderRetNilOnTxBlockBodyShouldErr(t *testing.T) { t.Parallel() @@ -442,6 +465,34 @@ func TestNewShardBootstrap_InvalidProcessTimeShouldErr(t *testing.T) { assert.True(t, errors.Is(err, process.ErrInvalidProcessWaitTime)) } +func TestNewShardBootstrap_NilEnableEpochsHandlerShouldErr(t *testing.T) { + t.Parallel() + + args := CreateShardBootstrapMockArguments() + args.EnableEpochsHandler = nil + + bs, err := sync.NewShardBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.True(t, errors.Is(err, process.ErrNilEnableEpochsHandler)) +} + +func TestNewShardBootstrap_PoolsHolderRetNilOnProofsShouldErr(t *testing.T) { + t.Parallel() + + args := CreateShardBootstrapMockArguments() + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return nil + } + args.PoolsHolder = pools + + bs, err := sync.NewShardBootstrap(args) + + assert.True(t, check.IfNil(bs)) + assert.Equal(t, process.ErrNilProofsPool, err) +} + func TestNewShardBootstrap_MissingStorer(t *testing.T) { t.Parallel() @@ -491,13 +542,17 @@ func TestNewShardBootstrap_OkValsShouldWork(t *testing.T) { return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { wasCalled++ } return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } + args.PoolsHolder = pools args.IsInImportMode = true bs, err := sync.NewShardBootstrap(args) @@ -708,7 +763,7 @@ func TestBootstrap_SyncShouldSyncOneBlock(t *testing.T) { return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := 
cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -721,6 +776,10 @@ func TestBootstrap_SyncShouldSyncOneBlock(t *testing.T) { return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } + args.PoolsHolder = pools forkDetector := &mock.ForkDetectorMock{} @@ -803,7 +862,7 @@ func TestBootstrap_ShouldReturnNilErr(t *testing.T) { return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -816,6 +875,9 @@ func TestBootstrap_ShouldReturnNilErr(t *testing.T) { return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } args.PoolsHolder = pools forkDetector := &mock.ForkDetectorMock{} @@ -885,7 +947,7 @@ func TestBootstrap_SyncBlockShouldReturnErrorWhenProcessBlockFailed(t *testing.T return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -898,6 +960,9 @@ func TestBootstrap_SyncBlockShouldReturnErrorWhenProcessBlockFailed(t *testing.T return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } args.PoolsHolder = pools forkDetector := &mock.ForkDetectorMock{} @@ -1103,6 +1168,8 @@ func TestBootstrap_GetNodeStateShouldReturnNotSynchronizedWhenForkIsDetectedAndI &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) bs, _ := sync.NewShardBootstrap(args) @@ -1178,6 +1245,8 @@ func TestBootstrap_GetNodeStateShouldReturnSynchronizedWhenForkIsDetectedAndItRe &testscommon.TimeCacheStub{}, &mock.BlockTrackerMock{}, 0, + &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + &dataRetrieverMock.ProofsPoolMock{}, ) bs, _ := sync.NewShardBootstrap(args) @@ -1874,12 +1943,15 @@ func TestShardBootstrap_RequestMiniBlocksFromHeaderWithNonceIfMissing(t *testing return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } args.PoolsHolder = pools blkc := initBlockchain() @@ -2093,7 +2165,7 @@ func TestShardBootstrap_SyncBlockGetNodeDBErrorShouldSync(t *testing.T) { return sds } pools.MiniBlocksCalled = func() storage.Cacher { - cs := testscommon.NewCacherStub() + cs := cache.NewCacherStub() cs.RegisterHandlerCalled = func(i func(key []byte, value interface{})) { } cs.GetCalled = func(key []byte) (value interface{}, ok bool) { @@ -2106,6 +2178,9 @@ func TestShardBootstrap_SyncBlockGetNodeDBErrorShouldSync(t *testing.T) { return cs } + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + } args.PoolsHolder = pools forkDetector := &mock.ForkDetectorMock{} @@ -2144,13 +2219,281 @@ func TestShardBootstrap_SyncBlockGetNodeDBErrorShouldSync(t *testing.T) { return []byte("roothash"), nil }} - bs, _ := 
sync.NewShardBootstrap(args) + bs, err := sync.NewShardBootstrap(args) + require.Nil(t, err) - err := bs.SyncBlock(context.Background()) + err = bs.SyncBlock(context.Background()) assert.Equal(t, errGetNodeFromDB, err) assert.True(t, syncCalled) } +func TestShardBootstrap_SyncBlock_WithEquivalentProofs(t *testing.T) { + t.Parallel() + + t.Run("time is out when existing header and missing proof", func(t *testing.T) { + t.Parallel() + + args := CreateShardBootstrapMockArguments() + + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AndromedaFlag + }, + } + + hdr := block.Header{Nonce: 1} + blkc := &testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{} + }, + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &hdr + }, + } + args.ChainHandler = blkc + + forkDetector := &mock.ForkDetectorMock{} + forkDetector.CheckForkCalled = func() *process.ForkInfo { + return process.NewForkInfo() + } + forkDetector.ProbableHighestNonceCalled = func() uint64 { + return 100 + } + forkDetector.GetNotarizedHeaderHashCalled = func(nonce uint64) []byte { + return nil + } + args.ForkDetector = forkDetector + args.RoundHandler, _ = round.NewRound(time.Now(), + time.Now().Add(2*100*time.Millisecond), + 100*time.Millisecond, + &mock.SyncTimerMock{}, + 0, + ) + args.BlockProcessor = createBlockProcessor(args.ChainHandler) + + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + GetProofByNonceCalled: func(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + return nil, errors.New("missing proof") + }, + } + } + + args.PoolsHolder = pools + + bs, _ := sync.NewShardBootstrap(args) + r := bs.SyncBlock(context.Background()) + + assert.Equal(t, process.ErrTimeIsOut, r) + }) + + t.Run("should receive header and proof if missing, requesting by nonce", func(t *testing.T) { + t.Parallel() + + args := CreateShardBootstrapMockArguments() + + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return flag == common.AndromedaFlag + }, + } + + hdr := block.Header{Nonce: 1} + blkc := &testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{} + }, + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &hdr + }, + } + args.ChainHandler = blkc + + forkDetector := &mock.ForkDetectorMock{} + forkDetector.CheckForkCalled = func() *process.ForkInfo { + return process.NewForkInfo() + } + forkDetector.ProbableHighestNonceCalled = func() uint64 { + return 100 + } + forkDetector.GetNotarizedHeaderHashCalled = func(nonce uint64) []byte { + return nil + } + args.ForkDetector = forkDetector + args.RoundHandler, _ = round.NewRound(time.Now(), + time.Now().Add(2*100*time.Millisecond), + 100*time.Millisecond, + &mock.SyncTimerMock{}, + 0, + ) + args.BlockProcessor = createBlockProcessor(args.ChainHandler) + + pools := createMockPools() + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + GetProofByNonceCalled: func(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + return nil, errors.New("missing proof") + }, + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + return true // second check after wait is done by hash + 
}, + } + } + + numHeaderCalls := 0 + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + sds.GetHeaderByNonceAndShardIdCalled = func(hdrNonce uint64, shardId uint32) (handlers []data.HeaderHandler, i [][]byte, e error) { + if numHeaderCalls == 0 { + numHeaderCalls++ + return nil, nil, errors.New("err") + } + + return []data.HeaderHandler{ + &block.Header{ + Nonce: 1, + Round: 1, + RootHash: []byte("bbb")}, + }, [][]byte{[]byte("aaa")}, nil + } + + return sds + } + args.PoolsHolder = pools + + receive := make(chan bool, 2) + + args.RequestHandler = &testscommon.RequestHandlerStub{ + RequestShardHeaderByNonceCalled: func(shardID uint32, nonce uint64) { + receive <- true + }, + RequestEquivalentProofByNonceCalled: func(headerShard uint32, headerNonce uint64) { + receive <- true + }, + } + + bs, _ := sync.NewShardBootstrap(args) + + go func() { + // wait for both header and proof requests + <-receive + <-receive + + bs.SetRcvHdrNonce() + }() + + err := bs.SyncBlock(context.Background()) + + assert.Nil(t, err) + }) + + t.Run("should receive header and proof if missing, requesting by hash", func(t *testing.T) { + t.Parallel() + + args := CreateShardBootstrapMockArguments() + + args.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool { + return flag == common.AndromedaFlag + }, + } + + hdr := block.Header{Nonce: 1} + blkc := &testscommon.ChainHandlerStub{ + GetGenesisHeaderCalled: func() data.HeaderHandler { + return &block.Header{} + }, + GetCurrentBlockHeaderCalled: func() data.HeaderHandler { + return &hdr + }, + } + args.ChainHandler = blkc + + forkDetector := &mock.ForkDetectorMock{} + forkDetector.CheckForkCalled = func() *process.ForkInfo { + return process.NewForkInfo() + } + forkDetector.ProbableHighestNonceCalled = func() uint64 { + return 100 + } + + hash := []byte("hash1") + forkDetector.GetNotarizedHeaderHashCalled = func(nonce uint64) []byte { + return hash + } + args.ForkDetector = forkDetector + args.RoundHandler, _ = round.NewRound(time.Now(), + time.Now().Add(2*100*time.Millisecond), + 100*time.Millisecond, + &mock.SyncTimerMock{}, + 0, + ) + args.BlockProcessor = createBlockProcessor(args.ChainHandler) + + pools := createMockPools() + + numProofCalls := 0 + pools.ProofsCalled = func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{ + GetProofCalled: func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + return nil, errors.New("missing proof") + }, + HasProofCalled: func(shardID uint32, headerHash []byte) bool { + if numProofCalls == 0 { + numProofCalls++ + return false + } + + return true // second check after wait is done by hash + }, + } + } + + numHeaderCalls := 0 + pools.HeadersCalled = func() dataRetriever.HeadersPool { + sds := &mock.HeadersCacherStub{} + + sds.GetHeaderByHashCalled = func(hash []byte) (data.HeaderHandler, error) { + if numHeaderCalls == 0 { + numHeaderCalls++ + return nil, errors.New("err") + } + + return &block.Header{}, nil + } + + return sds + } + args.PoolsHolder = pools + + receive := make(chan bool, 2) + + args.RequestHandler = &testscommon.RequestHandlerStub{ + RequestShardHeaderCalled: func(shardID uint32, hash []byte) { + receive <- true + }, + RequestEquivalentProofByHashCalled: func(headerShard uint32, headerHash []byte) { + receive <- true + }, + } + + bs, _ := sync.NewShardBootstrap(args) + + go func() { + // wait for both header and proof requests + <-receive + <-receive + + 
bs.SetRcvHdrHash()
+		}()
+
+		err := bs.SyncBlock(context.Background())
+
+		assert.Nil(t, err)
+	})
+}
+
 func TestShardBootstrap_NilInnerBootstrapperClose(t *testing.T) {
 	t.Parallel()

diff --git a/process/sync/storageBootstrap/baseStorageBootstrapper.go b/process/sync/storageBootstrap/baseStorageBootstrapper.go
index a1326ac5f65..d42a9456f3d 100644
--- a/process/sync/storageBootstrap/baseStorageBootstrapper.go
+++ b/process/sync/storageBootstrap/baseStorageBootstrapper.go
@@ -9,6 +9,9 @@ import (
 	"github.com/multiversx/mx-chain-core-go/data/block"
 	"github.com/multiversx/mx-chain-core-go/data/typeConverters"
 	"github.com/multiversx/mx-chain-core-go/marshal"
+	logger "github.com/multiversx/mx-chain-logger-go"
+
+	"github.com/multiversx/mx-chain-go/common"
 	"github.com/multiversx/mx-chain-go/dataRetriever"
 	"github.com/multiversx/mx-chain-go/process"
 	"github.com/multiversx/mx-chain-go/process/block/bootstrapStorage"
@@ -17,7 +20,6 @@ import (
 	"github.com/multiversx/mx-chain-go/sharding"
 	"github.com/multiversx/mx-chain-go/sharding/nodesCoordinator"
 	"github.com/multiversx/mx-chain-go/storage"
-	logger "github.com/multiversx/mx-chain-logger-go"
 )

 var log = logger.GetOrCreate("process/sync")

@@ -44,6 +46,8 @@ type ArgsBaseStorageBootstrapper struct {
 	EpochNotifier              process.EpochNotifier
 	ProcessedMiniBlocksTracker process.ProcessedMiniBlocksTracker
 	AppStatusHandler           core.AppStatusHandler
+	EnableEpochsHandler        common.EnableEpochsHandler
+	ProofsPool                 process.ProofsPool
 }

 // ArgsShardStorageBootstrapper is structure used to create a new storage bootstrapper for shard
@@ -79,6 +83,8 @@ type storageBootstrapper struct {
 	epochNotifier              process.EpochNotifier
 	processedMiniBlocksTracker process.ProcessedMiniBlocksTracker
 	appStatusHandler           core.AppStatusHandler
+	enableEpochsHandler        common.EnableEpochsHandler
+	proofsPool                 process.ProofsPool
 }

 func (st *storageBootstrapper) loadBlocks() error {
@@ -292,6 +298,12 @@ func (st *storageBootstrapper) applyHeaderInfo(hdrInfo bootstrapStorage.Bootstra
 		return err
 	}

+	err = st.getAndApplyProofForHeader(headerHash, headerFromStorage)
+	if err != nil {
+		log.Debug("cannot apply proof for header ", "nonce", headerFromStorage.GetNonce(), "error", err.Error())
+		return err
+	}
+
 	return nil
 }

@@ -367,6 +379,12 @@ func (st *storageBootstrapper) applyBootInfos(bootInfos []bootstrapStorage.Boots
 		return err
 	}

+	err = st.getAndApplyProofForHeader(bootInfos[i].LastHeader.Hash, header)
+	if err != nil {
+		log.Debug("cannot get and apply header proof", "hash", bootInfos[i].LastHeader.Hash, "error", err.Error())
+		return err
+	}
+
 	log.Debug("add header to fork detector",
 		"shard", header.GetShardID(),
 		"round", header.GetRound(),
@@ -446,6 +464,44 @@ func (st *storageBootstrapper) applyBlock(headerHash []byte, header data.HeaderH
 	st.blkc.SetCurrentBlockHeaderHash(headerHash)

+	if !st.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) {
+		return nil
+	}
+
+	isFlagEnabledAfterEpochStart := common.IsFlagEnabledAfterEpochsStartBlock(header, st.enableEpochsHandler, common.AndromedaFlag)
+
+	st.forkDetector.AddCheckpoint(header.GetNonce(), header.GetRound(), headerHash)
+	if header.GetShardID() == core.MetachainShardId || isFlagEnabledAfterEpochStart {
+		st.forkDetector.SetFinalToLastCheckpoint()
+		st.forkDetector.ResetProbableHighestNonce()
+	}
+
+	return nil
+}
+
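+// getAndApplyProofForHeader restores the equivalent proof for the given header: when the
+// Andromeda flag is active for the header's epoch, it loads the marshalled proof from the
+// ProofsUnit storer by header hash, unmarshals it and adds it back into the proofs pool.
+// For earlier epochs it is a no-op.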
+func (st *storageBootstrapper) getAndApplyProofForHeader(headerHash []byte, header data.HeaderHandler) error {
+	if !st.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) {
+		return nil
+	}
+
+	proofsStorer, err := st.store.GetStorer(dataRetriever.ProofsUnit)
+	if err != nil {
+		return err
+	}
+
+	marshaledProof, err := proofsStorer.SearchFirst(headerHash)
+	if err != nil {
+		return err
+	}
+
+	proof := &block.HeaderProof{}
+	err = st.marshalizer.Unmarshal(proof, marshaledProof)
+	if err != nil {
+		return err
+	}
+
+	st.proofsPool.AddProof(proof)
+
 	return nil
 }

@@ -513,6 +569,12 @@ func checkBaseStorageBootstrapperArguments(args ArgsBaseStorageBootstrapper) err
 	if check.IfNil(args.AppStatusHandler) {
 		return process.ErrNilAppStatusHandler
 	}
+	if check.IfNil(args.EnableEpochsHandler) {
+		return process.ErrNilEnableEpochsHandler
+	}
+	if check.IfNil(args.ProofsPool) {
+		return process.ErrNilProofsPool
+	}

 	return nil
 }

@@ -535,6 +597,11 @@ func (st *storageBootstrapper) restoreBlockBodyIntoPools(headerHash []byte) erro
 		return err
 	}

+	err = st.getAndApplyProofForHeader(headerHash, headerHandler)
+	if err != nil {
+		return err
+	}
+
 	return nil
 }

diff --git a/process/sync/storageBootstrap/baseStorageBootstrapper_test.go b/process/sync/storageBootstrap/baseStorageBootstrapper_test.go
index fd84771ea26..1a4e1d9c3ad 100644
--- a/process/sync/storageBootstrap/baseStorageBootstrapper_test.go
+++ b/process/sync/storageBootstrap/baseStorageBootstrapper_test.go
@@ -7,20 +7,23 @@ import (
 	"github.com/multiversx/mx-chain-core-go/data"
 	"github.com/multiversx/mx-chain-core-go/data/block"
+	dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever"
+	"github.com/stretchr/testify/assert"
+
 	"github.com/multiversx/mx-chain-go/dataRetriever"
 	"github.com/multiversx/mx-chain-go/process"
 	"github.com/multiversx/mx-chain-go/process/mock"
 	"github.com/multiversx/mx-chain-go/storage"
 	"github.com/multiversx/mx-chain-go/testscommon"
+	"github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock"
 	epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier"
 	"github.com/multiversx/mx-chain-go/testscommon/genericMocks"
 	"github.com/multiversx/mx-chain-go/testscommon/shardingMocks"
 	"github.com/multiversx/mx-chain-go/testscommon/statusHandler"
 	storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage"
-	"github.com/stretchr/testify/assert"
 )

-func createMockShardStorageBoostrapperArgs() ArgsBaseStorageBootstrapper {
+func createMockShardStorageBootstrapperArgs() ArgsBaseStorageBootstrapper {
 	argsBaseBootstrapper := ArgsBaseStorageBootstrapper{
 		BootStorer:   &mock.BoostrapStorerMock{},
 		ForkDetector: &mock.ForkDetectorMock{},
@@ -44,6 +47,8 @@ func createMockShardStorageBoostrapperArgs() ArgsBaseStorageBootstrapper {
 		EpochNotifier:              &epochNotifierMock.EpochNotifierStub{},
 		ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{},
 		AppStatusHandler:           &statusHandler.AppStatusHandlerMock{},
+		EnableEpochsHandler:        &enableEpochsHandlerMock.EnableEpochsHandlerStub{},
+		ProofsPool:                 &dataRetrieverMocks.ProofsPoolMock{},
 	}

 	return argsBaseBootstrapper
@@ -55,7 +60,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin
 	t.Run("nil bootstorer 
should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.BootStorer = nil err := checkBaseStorageBootstrapperArguments(args) @@ -64,7 +69,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil fork detector should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.ForkDetector = nil err := checkBaseStorageBootstrapperArguments(args) @@ -73,7 +78,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil block processor should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.BlockProcessor = nil err := checkBaseStorageBootstrapperArguments(args) @@ -82,7 +87,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil chain handler should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.ChainHandler = nil err := checkBaseStorageBootstrapperArguments(args) @@ -91,7 +96,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil marshaller should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.Marshalizer = nil err := checkBaseStorageBootstrapperArguments(args) @@ -100,7 +105,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil store should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.Store = nil err := checkBaseStorageBootstrapperArguments(args) @@ -109,7 +114,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil uint64 converter should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.Uint64Converter = nil err := checkBaseStorageBootstrapperArguments(args) @@ -118,7 +123,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil shard coordinator should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.ShardCoordinator = nil err := checkBaseStorageBootstrapperArguments(args) @@ -127,7 +132,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil nodes coordinator should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.NodesCoordinator = nil err := checkBaseStorageBootstrapperArguments(args) @@ -136,7 +141,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil epoch start trigger should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.EpochStartTrigger = nil err := checkBaseStorageBootstrapperArguments(args) @@ -145,7 +150,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil block tracker should error", func(t *testing.T) { 
t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.BlockTracker = nil err := checkBaseStorageBootstrapperArguments(args) @@ -154,7 +159,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil scheduled txs execution should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.ScheduledTxsExecutionHandler = nil err := checkBaseStorageBootstrapperArguments(args) @@ -163,7 +168,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil miniblocks provider should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.MiniblocksProvider = nil err := checkBaseStorageBootstrapperArguments(args) @@ -172,7 +177,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil epoch notifier should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.EpochNotifier = nil err := checkBaseStorageBootstrapperArguments(args) @@ -181,7 +186,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil processed mini blocks tracker should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.ProcessedMiniBlocksTracker = nil err := checkBaseStorageBootstrapperArguments(args) @@ -190,7 +195,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin t.Run("nil app status handler - should error", func(t *testing.T) { t.Parallel() - args := createMockShardStorageBoostrapperArgs() + args := createMockShardStorageBootstrapperArgs() args.AppStatusHandler = nil err := checkBaseStorageBootstrapperArguments(args) @@ -201,7 +206,7 @@ func TestBaseStorageBootstrapper_CheckBaseStorageBootstrapperArguments(t *testin func TestBaseStorageBootstrapper_RestoreBlockBodyIntoPoolsShouldErrMissingHeader(t *testing.T) { t.Parallel() - baseArgs := createMockShardStorageBoostrapperArgs() + baseArgs := createMockShardStorageBootstrapperArgs() baseArgs.Store = &storageStubs.ChainStorerStub{ GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { return &storageStubs.StorerStub{ @@ -228,7 +233,7 @@ func TestBaseStorageBootstrapper_RestoreBlockBodyIntoPoolsShouldErrMissingBody(t headerHash := []byte("header_hash") header := &block.Header{} - baseArgs := createMockShardStorageBoostrapperArgs() + baseArgs := createMockShardStorageBootstrapperArgs() baseArgs.MiniblocksProvider = &mock.MiniBlocksProviderStub{ GetMiniBlocksFromStorerCalled: func(hashes [][]byte) ([]*block.MiniblockAndHash, [][]byte) { return nil, [][]byte{[]byte("missing_hash")} @@ -258,7 +263,7 @@ func TestBaseStorageBootstrapper_RestoreBlockBodyIntoPoolsShouldErrWhenRestoreBl headerHash := []byte("header_hash") header := &block.Header{} - baseArgs := createMockShardStorageBoostrapperArgs() + baseArgs := createMockShardStorageBootstrapperArgs() baseArgs.MiniblocksProvider = &mock.MiniBlocksProviderStub{ GetMiniBlocksFromStorerCalled: func(hashes [][]byte) ([]*block.MiniblockAndHash, [][]byte) { return nil, nil @@ -292,7 +297,7 @@ func TestBaseStorageBootstrapper_RestoreBlockBodyIntoPoolsShouldWork(t *testing. 
headerHash := []byte("header_hash") header := &block.Header{} - baseArgs := createMockShardStorageBoostrapperArgs() + baseArgs := createMockShardStorageBootstrapperArgs() baseArgs.MiniblocksProvider = &mock.MiniBlocksProviderStub{ GetMiniBlocksFromStorerCalled: func(hashes [][]byte) ([]*block.MiniblockAndHash, [][]byte) { return nil, nil @@ -325,7 +330,7 @@ func TestBaseStorageBootstrapper_GetBlockBodyShouldErrMissingBody(t *testing.T) header := &block.Header{} - baseArgs := createMockShardStorageBoostrapperArgs() + baseArgs := createMockShardStorageBootstrapperArgs() baseArgs.MiniblocksProvider = &mock.MiniBlocksProviderStub{ GetMiniBlocksFromStorerCalled: func(hashes [][]byte) ([]*block.MiniblockAndHash, [][]byte) { return nil, [][]byte{[]byte("missing_hash")} @@ -370,7 +375,7 @@ func TestBaseStorageBootstrapper_GetBlockBodyShouldWork(t *testing.T) { } header := &block.Header{} - baseArgs := createMockShardStorageBoostrapperArgs() + baseArgs := createMockShardStorageBootstrapperArgs() baseArgs.MiniblocksProvider = &mock.MiniBlocksProviderStub{ GetMiniBlocksFromStorerCalled: func(hashes [][]byte) ([]*block.MiniblockAndHash, [][]byte) { return mbAndHashes, nil diff --git a/process/sync/storageBootstrap/metaStorageBootstrapper.go b/process/sync/storageBootstrap/metaStorageBootstrapper.go index ceac6df4f9c..c236018229f 100644 --- a/process/sync/storageBootstrap/metaStorageBootstrapper.go +++ b/process/sync/storageBootstrap/metaStorageBootstrapper.go @@ -3,6 +3,7 @@ package storageBootstrap import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" @@ -41,6 +42,8 @@ func NewMetaStorageBootstrapper(arguments ArgsMetaStorageBootstrapper) (*metaSto epochNotifier: arguments.EpochNotifier, processedMiniBlocksTracker: arguments.ProcessedMiniBlocksTracker, appStatusHandler: arguments.AppStatusHandler, + enableEpochsHandler: arguments.EnableEpochsHandler, + proofsPool: arguments.ProofsPool, } boot := metaStorageBootstrapper{ @@ -74,6 +77,11 @@ func (msb *metaStorageBootstrapper) applyCrossNotarizedHeaders(crossNotarizedHea return err } + err = msb.getAndApplyProofForHeader(crossNotarizedHeader.Hash, header) + if err != nil { + return err + } + log.Debug("added cross notarized header in block tracker", "shard", crossNotarizedHeader.ShardId, "round", header.GetRound(), @@ -124,7 +132,7 @@ func (msb *metaStorageBootstrapper) cleanupNotarizedStorage(metaBlockHash []byte "nonce", shardHeader.GetNonce(), "hash", shardHeaderHash) - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(shardHeader.GetShardID()) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(shardHeader.GetShardID()) storer, err := msb.store.GetStorer(hdrNonceHashDataUnit) if err != nil { log.Debug("could not get storage unit", @@ -158,6 +166,11 @@ func (msb *metaStorageBootstrapper) applySelfNotarizedHeaders( return nil, nil, err } + err = msb.getAndApplyProofForHeader(bootstrapHeaderInfo.Hash, selfNotarizedHeader) + if err != nil { + return nil, nil, err + } + log.Debug("added self notarized header in block tracker", "shard", bootstrapHeaderInfo.ShardId, "round", selfNotarizedHeader.GetRound(), diff --git a/process/sync/storageBootstrap/shardStorageBootstrapper.go 
b/process/sync/storageBootstrap/shardStorageBootstrapper.go index fe327de1de6..ebc8992df05 100644 --- a/process/sync/storageBootstrap/shardStorageBootstrapper.go +++ b/process/sync/storageBootstrap/shardStorageBootstrapper.go @@ -4,6 +4,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" @@ -42,6 +43,8 @@ func NewShardStorageBootstrapper(arguments ArgsShardStorageBootstrapper) (*shard epochNotifier: arguments.EpochNotifier, processedMiniBlocksTracker: arguments.ProcessedMiniBlocksTracker, appStatusHandler: arguments.AppStatusHandler, + enableEpochsHandler: arguments.EnableEpochsHandler, + proofsPool: arguments.ProofsPool, } boot := shardStorageBootstrapper{ @@ -49,7 +52,7 @@ func NewShardStorageBootstrapper(arguments ArgsShardStorageBootstrapper) (*shard } base.bootstrapper = &boot - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(boot.shardCoordinator.SelfId()) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(boot.shardCoordinator.SelfId()) base.headerNonceHashStore, err = boot.store.GetStorer(hdrNonceHashDataUnit) if err != nil { return nil, err @@ -87,6 +90,11 @@ func (ssb *shardStorageBootstrapper) applyCrossNotarizedHeaders(crossNotarizedHe return err } + err = ssb.getAndApplyProofForHeader(crossNotarizedHeader.Hash, metaBlock) + if err != nil { + return err + } + log.Debug("added cross notarized header in block tracker", "shard", core.MetachainShardId, "round", metaBlock.GetRound(), @@ -252,6 +260,11 @@ func (ssb *shardStorageBootstrapper) applySelfNotarizedHeaders( return nil, nil, err } + err = ssb.getAndApplyProofForHeader(selfNotarizedHeaderHash, selfNotarizedHeader) + if err != nil { + return nil, nil, err + } + selfNotarizedHeaders[index] = selfNotarizedHeader log.Debug("added self notarized header in block tracker", diff --git a/process/sync/storageBootstrap/shardStorageBootstrapper_test.go b/process/sync/storageBootstrap/shardStorageBootstrapper_test.go index f518b21b788..8ab7f337e93 100644 --- a/process/sync/storageBootstrap/shardStorageBootstrapper_test.go +++ b/process/sync/storageBootstrap/shardStorageBootstrapper_test.go @@ -8,6 +8,10 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/block/bootstrapStorage" @@ -15,14 +19,13 @@ import ( "github.com/multiversx/mx-chain-go/process/sync" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" epochNotifierMock 
"github.com/multiversx/mx-chain-go/testscommon/epochNotifier" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" storageMock "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestShardStorageBootstrapper_LoadFromStorageShouldWork(t *testing.T) { @@ -136,6 +139,8 @@ func TestShardStorageBootstrapper_LoadFromStorageShouldWork(t *testing.T) { }, ProcessedMiniBlocksTracker: &testscommon.ProcessedMiniBlocksTrackerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ProofsPool: &dataRetrieverMocks.ProofsPoolMock{}, }, } @@ -153,7 +158,7 @@ func TestShardStorageBootstrapper_LoadFromStorageShouldWork(t *testing.T) { } func TestShardStorageBootstrapper_CleanupNotarizedStorageForHigherNoncesIfExist(t *testing.T) { - baseArgs := createMockShardStorageBoostrapperArgs() + baseArgs := createMockShardStorageBootstrapperArgs() bForceError := true numCalled := 0 diff --git a/process/throttle/antiflood/blackList/p2pBlackListProcessor_test.go b/process/throttle/antiflood/blackList/p2pBlackListProcessor_test.go index 0d5eee28a06..686b49031d1 100644 --- a/process/throttle/antiflood/blackList/p2pBlackListProcessor_test.go +++ b/process/throttle/antiflood/blackList/p2pBlackListProcessor_test.go @@ -7,16 +7,18 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/throttle/antiflood/blackList" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) const selfPid = "current pid" -//-------- NewP2PQuotaBlacklistProcessor +// -------- NewP2PQuotaBlacklistProcessor func TestNewP2PQuotaBlacklistProcessor_NilCacherShouldErr(t *testing.T) { t.Parallel() @@ -40,7 +42,7 @@ func TestNewP2PQuotaBlacklistProcessor_NilBlackListHandlerShouldErr(t *testing.T t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), nil, 1, 1, @@ -58,7 +60,7 @@ func TestNewP2PQuotaBlacklistProcessor_InvalidThresholdNumReceivedFloodShouldErr t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), &mock.PeerBlackListHandlerStub{}, 0, 1, @@ -76,7 +78,7 @@ func TestNewP2PQuotaBlacklistProcessor_InvalidThresholdSizeReceivedFloodShouldEr t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), &mock.PeerBlackListHandlerStub{}, 1, 0, @@ -94,7 +96,7 @@ func TestNewP2PQuotaBlacklistProcessor_InvalidNumFloodingRoundsShouldErr(t *test t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), &mock.PeerBlackListHandlerStub{}, 1, 1, @@ -112,7 +114,7 @@ func 
TestNewP2PQuotaBlacklistProcessor_InvalidBanDurationShouldErr(t *testing.T) t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), &mock.PeerBlackListHandlerStub{}, 1, 1, @@ -130,7 +132,7 @@ func TestNewP2PQuotaBlacklistProcessor_ShouldWork(t *testing.T) { t.Parallel() pbp, err := blackList.NewP2PBlackListProcessor( - testscommon.NewCacherStub(), + cache.NewCacherStub(), &mock.PeerBlackListHandlerStub{}, 1, 1, @@ -144,7 +146,7 @@ func TestNewP2PQuotaBlacklistProcessor_ShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- AddQuota +// ------- AddQuota func TestP2PQuotaBlacklistProcessor_AddQuotaUnderThresholdShouldNotCallGetOrPut(t *testing.T) { t.Parallel() @@ -153,7 +155,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaUnderThresholdShouldNotCallGetOrPut( thresholdSize := uint64(20) pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (interface{}, bool) { assert.Fail(t, "should not have called get") return nil, false @@ -184,7 +186,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaOverThresholdInexistentDataOnGetShou putCalled := false identifier := core.PeerID("identifier") pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (interface{}, bool) { return nil, false }, @@ -219,7 +221,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaOverThresholdDataNotValidOnGetShould putCalled := false identifier := core.PeerID("identifier") pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (interface{}, bool) { return "invalid data", true }, @@ -255,7 +257,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaShouldIncrement(t *testing.T) { identifier := core.PeerID("identifier") existingValue := uint32(445) pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (interface{}, bool) { return existingValue, true }, @@ -290,7 +292,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaForSelfShouldNotIncrement(t *testing putCalled := false existingValue := uint32(445) pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ GetCalled: func(key []byte) (interface{}, bool) { return existingValue, true }, @@ -313,7 +315,7 @@ func TestP2PQuotaBlacklistProcessor_AddQuotaForSelfShouldNotIncrement(t *testing assert.False(t, putCalled) } -//------- ResetStatistics +// ------- ResetStatistics func TestP2PQuotaBlacklistProcessor_ResetStatisticsRemoveNilValueKey(t *testing.T) { t.Parallel() @@ -324,7 +326,7 @@ func TestP2PQuotaBlacklistProcessor_ResetStatisticsRemoveNilValueKey(t *testing. 
nilValKey := "nil val key" removedCalled := false pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ KeysCalled: func() [][]byte { return [][]byte{[]byte(nilValKey)} }, @@ -360,7 +362,7 @@ func TestP2PQuotaBlacklistProcessor_ResetStatisticsShouldRemoveInvalidValueKey(t invalidValKey := "invalid val key" removedCalled := false pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ KeysCalled: func() [][]byte { return [][]byte{[]byte(invalidValKey)} }, @@ -399,7 +401,7 @@ func TestP2PQuotaBlacklistProcessor_ResetStatisticsUnderNumFloodingRoundsShouldN upsertCalled := false duration := time.Second * 3892 pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ KeysCalled: func() [][]byte { return [][]byte{[]byte(key)} }, @@ -444,7 +446,7 @@ func TestP2PQuotaBlacklistProcessor_ResetStatisticsOverNumFloodingRoundsShouldBl upsertCalled := false duration := time.Second * 3892 pbp, _ := blackList.NewP2PBlackListProcessor( - &testscommon.CacherStub{ + &cache.CacherStub{ KeysCalled: func() [][]byte { return [][]byte{[]byte(key)} }, diff --git a/process/throttle/antiflood/disabled/antiflood.go b/process/throttle/antiflood/disabled/antiflood.go index cdae45f21c1..99467e63c9a 100644 --- a/process/throttle/antiflood/disabled/antiflood.go +++ b/process/throttle/antiflood/disabled/antiflood.go @@ -47,8 +47,8 @@ func (af *AntiFlood) CanProcessMessagesOnTopic(_ core.PeerID, _ string, _ uint32 return nil } -// ApplyConsensusSize does nothing -func (af *AntiFlood) ApplyConsensusSize(_ int) { +// SetConsensusSizeNotifier does nothing +func (af *AntiFlood) SetConsensusSizeNotifier(_ process.ChainParametersSubscriber, _ uint32) { } // SetDebugger returns nil diff --git a/process/throttle/antiflood/disabled/antiflood_test.go b/process/throttle/antiflood/disabled/antiflood_test.go index e1118894cc4..a5908cc6f07 100644 --- a/process/throttle/antiflood/disabled/antiflood_test.go +++ b/process/throttle/antiflood/disabled/antiflood_test.go @@ -22,7 +22,7 @@ func TestAntiFlood_ShouldNotPanic(t *testing.T) { daf.SetMaxMessagesForTopic("test", 10) daf.ResetForTopic("test") - daf.ApplyConsensusSize(0) + daf.SetConsensusSizeNotifier(nil, 0) _ = daf.CanProcessMessagesOnTopic(core.PeerID(fmt.Sprint(1)), "test", 1, 0, nil) _ = daf.CanProcessMessage(nil, core.PeerID(fmt.Sprint(2))) } diff --git a/process/throttle/antiflood/export_test.go b/process/throttle/antiflood/export_test.go index bd97917572c..25fbf8bae30 100644 --- a/process/throttle/antiflood/export_test.go +++ b/process/throttle/antiflood/export_test.go @@ -3,5 +3,8 @@ package antiflood import "github.com/multiversx/mx-chain-go/process" func (af *p2pAntiflood) Debugger() process.AntifloodDebugger { + af.mutDebugger.RLock() + defer af.mutDebugger.RUnlock() + return af.debugger } diff --git a/process/throttle/antiflood/floodPreventers/quotaFloodPreventer_test.go b/process/throttle/antiflood/floodPreventers/quotaFloodPreventer_test.go index 068ba97591d..5dc21b68e35 100644 --- a/process/throttle/antiflood/floodPreventers/quotaFloodPreventer_test.go +++ b/process/throttle/antiflood/floodPreventers/quotaFloodPreventer_test.go @@ -9,16 +9,18 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" - 
"github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" ) func createDefaultArgument() ArgQuotaFloodPreventer { return ArgQuotaFloodPreventer{ Name: "test", - Cacher: testscommon.NewCacherStub(), + Cacher: cache.NewCacherStub(), StatusHandlers: []QuotaStatusHandler{&mock.QuotaStatusHandlerStub{}}, BaseMaxNumMessagesPerPeer: minMessages, MaxTotalSizePerPeer: minTotalSize, @@ -28,7 +30,7 @@ func createDefaultArgument() ArgQuotaFloodPreventer { } } -//------- NewQuotaFloodPreventer +// ------- NewQuotaFloodPreventer func TestNewQuotaFloodPreventer_NilCacherShouldErr(t *testing.T) { t.Parallel() @@ -128,7 +130,7 @@ func TestNewQuotaFloodPreventer_NilListShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- IncreaseLoad +// ------- IncreaseLoad func TestNewQuotaFloodPreventer_IncreaseLoadIdentifierNotPresentPutQuotaAndReturnTrue(t *testing.T) { t.Parallel() @@ -136,7 +138,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadIdentifierNotPresentPutQuotaAndRetur putWasCalled := false size := uint64(minTotalSize * 5) arg := createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return nil, false }, @@ -168,7 +170,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadNotQuotaSavedInCacheShouldPutQuotaAn putWasCalled := false size := uint64(minTotalSize * 5) arg := createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return "bad value", true }, @@ -205,7 +207,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadUnderMaxValuesShouldIncrementAndRetu } size := uint64(minTotalSize * 2) arg := createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return existingQuota, true }, @@ -219,7 +221,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadUnderMaxValuesShouldIncrementAndRetu assert.Nil(t, err) } -//------- IncreaseLoad per peer +// ------- IncreaseLoad per peer func TestNewQuotaFloodPreventer_IncreaseLoadOverMaxPeerNumMessagesShouldNotPutAndReturnFalse(t *testing.T) { t.Parallel() @@ -231,7 +233,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadOverMaxPeerNumMessagesShouldNotPutAn sizeReceivedMessages: existingSize, } arg := createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return existingQuota, true }, @@ -260,7 +262,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadOverMaxPeerSizeShouldNotPutAndReturn sizeReceivedMessages: existingSize, } arg := createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { return existingQuota, true }, @@ -284,7 +286,7 @@ func TestCountersMap_IncreaseLoadShouldWorkConcurrently(t *testing.T) { numIterations := 1000 arg := createDefaultArgument() - arg.Cacher = testscommon.NewCacherMock() + arg.Cacher = cache.NewCacherMock() qfp, _ := NewQuotaFloodPreventer(arg) wg := sync.WaitGroup{} wg.Add(numIterations) @@ -299,14 +301,14 @@ func TestCountersMap_IncreaseLoadShouldWorkConcurrently(t *testing.T) { wg.Wait() } -//------- Reset +// ------- Reset func TestCountersMap_ResetShouldCallCacherClear(t *testing.T) { t.Parallel() clearCalled := false arg 
:= createDefaultArgument() - arg.Cacher = &testscommon.CacherStub{ + arg.Cacher = &cache.CacherStub{ ClearCalled: func() { clearCalled = true }, @@ -324,7 +326,7 @@ func TestCountersMap_ResetShouldCallCacherClear(t *testing.T) { func TestCountersMap_ResetShouldCallQuotaStatus(t *testing.T) { t.Parallel() - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() key1 := core.PeerID("key1") quota1 := &quota{ numReceivedMessages: 1, @@ -391,7 +393,7 @@ func TestCountersMap_IncrementAndResetShouldWorkConcurrently(t *testing.T) { numIterations := 1000 arg := createDefaultArgument() - arg.Cacher = testscommon.NewCacherMock() + arg.Cacher = cache.NewCacherMock() qfp, _ := NewQuotaFloodPreventer(arg) wg := sync.WaitGroup{} wg.Add(numIterations + numIterations/10) @@ -418,7 +420,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadWithMockCacherShouldWork(t *testing. numMessages := uint32(100) arg := createDefaultArgument() - arg.Cacher = testscommon.NewCacherMock() + arg.Cacher = cache.NewCacherMock() arg.BaseMaxNumMessagesPerPeer = numMessages arg.MaxTotalSizePerPeer = math.MaxUint64 arg.PercentReserved = float32(17) @@ -437,7 +439,7 @@ func TestNewQuotaFloodPreventer_IncreaseLoadWithMockCacherShouldWork(t *testing. } } -//------- ApplyConsensusSize +// ------- ApplyConsensusSize func TestQuotaFloodPreventer_ApplyConsensusSizeInvalidConsensusSize(t *testing.T) { t.Parallel() @@ -468,7 +470,7 @@ func TestQuotaFloodPreventer_ApplyConsensusShouldWork(t *testing.T) { t.Parallel() arg := createDefaultArgument() - arg.Cacher = testscommon.NewCacherMock() + arg.Cacher = cache.NewCacherMock() arg.BaseMaxNumMessagesPerPeer = 2000 arg.IncreaseThreshold = 1000 arg.IncreaseFactor = 0.25 diff --git a/process/throttle/antiflood/p2pAntiflood.go b/process/throttle/antiflood/p2pAntiflood.go index 621a0af69a8..747aca92c84 100644 --- a/process/throttle/antiflood/p2pAntiflood.go +++ b/process/throttle/antiflood/p2pAntiflood.go @@ -7,6 +7,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/throttle/antiflood/disabled" @@ -27,6 +28,8 @@ type p2pAntiflood struct { peerValidatorMapper process.PeerValidatorMapper mapTopicsFromAll map[string]struct{} mutTopicCheck sync.RWMutex + shardID uint32 + mutShardID sync.RWMutex } // NewP2PAntiflood creates a new p2p anti flood protection mechanism built on top of a flood preventer implementation.
@@ -57,6 +60,31 @@ func NewP2PAntiflood( }, nil } +// SetConsensusSizeNotifier sets the consensus size notifier +func (af *p2pAntiflood) SetConsensusSizeNotifier(chainParametersNotifier process.ChainParametersSubscriber, shardID uint32) { + af.mutShardID.Lock() + af.shardID = shardID + af.mutShardID.Unlock() + + chainParametersNotifier.RegisterNotifyHandler(af) +} + +// ChainParametersChanged will be called when new chain parameters are confirmed on the network +func (af *p2pAntiflood) ChainParametersChanged(chainParameters config.ChainParametersByEpochConfig) { + af.mutShardID.RLock() + shardID := af.shardID + af.mutShardID.RUnlock() + + size := chainParameters.ShardConsensusGroupSize + if shardID == core.MetachainShardId { + size = chainParameters.MetachainConsensusGroupSize + } + + for _, fp := range af.floodPreventers { + fp.ApplyConsensusSize(int(size)) + } +} + // CanProcessMessage signals if a p2p message can be processed or not func (af *p2pAntiflood) CanProcessMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID) error { if message == nil { @@ -210,13 +238,6 @@ func (af *p2pAntiflood) ResetForTopic(topic string) { af.topicPreventer.ResetForTopic(topic) } -// ApplyConsensusSize applies the consensus size on all contained flood preventers -func (af *p2pAntiflood) ApplyConsensusSize(size int) { - for _, fp := range af.floodPreventers { - fp.ApplyConsensusSize(size) - } -} - // SetDebugger sets the antiflood debugger func (af *p2pAntiflood) SetDebugger(debugger process.AntifloodDebugger) error { if check.IfNil(debugger) { @@ -257,6 +278,9 @@ func (af *p2pAntiflood) BlacklistPeer(peer core.PeerID, reason string, duration // Close will call the close function on all sub components func (af *p2pAntiflood) Close() error { + af.mutDebugger.RLock() + defer af.mutDebugger.RUnlock() + return af.debugger.Close() } diff --git a/process/throttle/antiflood/p2pAntiflood_test.go b/process/throttle/antiflood/p2pAntiflood_test.go index 21ea5e99a8a..97637e2b621 100644 --- a/process/throttle/antiflood/p2pAntiflood_test.go +++ b/process/throttle/antiflood/p2pAntiflood_test.go @@ -2,17 +2,21 @@ package antiflood_test import ( "errors" + "sync" "sync/atomic" "testing" "time" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common/chainparametersnotifier" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/throttle/antiflood" "github.com/multiversx/mx-chain-go/process/throttle/antiflood/disabled" + "github.com/multiversx/mx-chain-go/testscommon/commonmocks" "github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/stretchr/testify/assert" ) @@ -307,24 +311,33 @@ func TestP2pAntiflood_ResetForTopicSetMaxMessagesShouldWork(t *testing.T) { assert.Equal(t, setMaxMessagesForTopicNum, setMaxMessagesForTopicParameter2) } -func TestP2pAntiflood_ApplyConsensusSize(t *testing.T) { +func TestP2pAntiflood_SetConsensusSizeNotifier(t *testing.T) { t.Parallel() wasCalled := false expectedSize := 878264 + testShardId := uint32(5) + var actualSize int afm, _ := antiflood.NewP2PAntiflood( &mock.PeerBlackListHandlerStub{}, &mock.TopicAntiFloodStub{}, 
&mock.FloodPreventerStub{ ApplyConsensusSizeCalled: func(size int) { - assert.Equal(t, expectedSize, size) + actualSize = size wasCalled = true }, }, ) - afm.ApplyConsensusSize(expectedSize) + chainParamsSubscriber := chainparametersnotifier.NewChainParametersNotifier() + afm.SetConsensusSizeNotifier(chainParamsSubscriber, testShardId) + + chainParamsSubscriber.UpdateCurrentChainParameters(config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(expectedSize), + }) + assert.True(t, wasCalled) + assert.Equal(t, expectedSize, actualSize) } func TestP2pAntiflood_SetDebuggerNilDebuggerShouldErr(t *testing.T) { @@ -464,3 +477,61 @@ func TestP2pAntiflood_IsOriginatorEligibleForTopic(t *testing.T) { err = afm.IsOriginatorEligibleForTopic(core.PeerID(validatorPID), "topic") assert.Nil(t, err) } + +func TestP2pAntiflood_ConcurrentOperations(t *testing.T) { + afm, _ := antiflood.NewP2PAntiflood( + &mock.PeerBlackListHandlerStub{}, + &mock.TopicAntiFloodStub{ + IncreaseLoadCalled: func(pid core.PeerID, topic string, numMessages uint32) error { + if topic == "should error" { + return errors.New("error") + } + + return nil + }, + }, + &mock.FloodPreventerStub{}, + ) + + numOperations := 500 + wg := sync.WaitGroup{} + wg.Add(numOperations) + for i := 0; i < numOperations; i++ { + go func(idx int) { + switch idx { + case 0: + afm.SetConsensusSizeNotifier(&commonmocks.ChainParametersNotifierStub{}, 1) + case 1: + afm.ChainParametersChanged(config.ChainParametersByEpochConfig{}) + case 2: + _ = afm.Close() + case 3: + _ = afm.CanProcessMessage(&p2pmocks.P2PMessageMock{}, "peer") + case 4: + afm.BlacklistPeer("peer", "reason", time.Millisecond) + case 5: + _ = afm.CanProcessMessagesOnTopic("peer", "topic", 37, 39, []byte("sequence")) + case 6: + _ = afm.IsOriginatorEligibleForTopic("peer", "topic") + case 7: + afm.ResetForTopic("topic") + case 8: + _ = afm.SetDebugger(&disabled.AntifloodDebugger{}) + case 9: + afm.SetMaxMessagesForTopic("topic", 37) + case 10: + afm.SetTopicsForAll("topic", "topic1") + case 11: + _ = afm.Debugger() + case 12: + _ = afm.SetPeerValidatorMapper(&mock.PeerShardResolverStub{}) + case 13: + _ = afm.CanProcessMessagesOnTopic("peer", "should error", 37, 39, []byte("sequence")) + } + + wg.Done() + }(i % 14) + } + + wg.Wait() +} diff --git a/process/track/argBlockProcessor.go b/process/track/argBlockProcessor.go index 0b7b02b20c9..6194e1c5e60 100644 --- a/process/track/argBlockProcessor.go +++ b/process/track/argBlockProcessor.go @@ -1,6 +1,11 @@ package track import ( + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" ) @@ -19,4 +24,10 @@ type ArgBlockProcessor struct { SelfNotarizedHeadersNotifier blockNotifierHandler FinalMetachainHeadersNotifier blockNotifierHandler RoundHandler process.RoundHandler + EnableEpochsHandler common.EnableEpochsHandler + ProofsPool process.ProofsPool + Marshaller marshal.Marshalizer + Hasher hashing.Hasher + HeadersPool dataRetriever.HeadersPool + IsImportDBMode bool } diff --git a/process/track/argBlockTrack.go b/process/track/argBlockTrack.go index ea655d3937b..e3512966d00 100644 --- a/process/track/argBlockTrack.go +++ b/process/track/argBlockTrack.go @@ -4,6 +4,8 @@ import ( 
"github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" @@ -12,17 +14,21 @@ import ( // ArgBaseTracker holds all dependencies required by the process data factory in order to create // new instances of shard/meta block tracker type ArgBaseTracker struct { - Hasher hashing.Hasher - HeaderValidator process.HeaderConstructionValidator - Marshalizer marshal.Marshalizer - RequestHandler process.RequestHandler - RoundHandler process.RoundHandler - ShardCoordinator sharding.Coordinator - Store dataRetriever.StorageService - StartHeaders map[uint32]data.HeaderHandler - PoolsHolder dataRetriever.PoolsHolder - WhitelistHandler process.WhiteListHandler - FeeHandler process.FeeHandler + Hasher hashing.Hasher + HeaderValidator process.HeaderConstructionValidator + Marshalizer marshal.Marshalizer + RequestHandler process.RequestHandler + RoundHandler process.RoundHandler + ShardCoordinator sharding.Coordinator + Store dataRetriever.StorageService + StartHeaders map[uint32]data.HeaderHandler + PoolsHolder dataRetriever.PoolsHolder + WhitelistHandler process.WhiteListHandler + FeeHandler process.FeeHandler + EnableEpochsHandler common.EnableEpochsHandler + ProofsPool process.ProofsPool + EpochChangeGracePeriodHandler common.EpochChangeGracePeriodHandler + IsImportDBMode bool } // ArgShardTracker holds all dependencies required by the process data factory in order to create diff --git a/process/track/baseBlockTrack.go b/process/track/baseBlockTrack.go index 22eb1c86cc1..e99be3c9fa7 100644 --- a/process/track/baseBlockTrack.go +++ b/process/track/baseBlockTrack.go @@ -12,10 +12,12 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + logger "github.com/multiversx/mx-chain-logger-go" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" - "github.com/multiversx/mx-chain-logger-go" ) var _ process.ValidityAttester = (*baseBlockTrack)(nil) @@ -35,6 +37,7 @@ type baseBlockTrack struct { roundHandler process.RoundHandler shardCoordinator sharding.Coordinator headersPool dataRetriever.HeadersPool + proofsPool dataRetriever.ProofsPool store dataRetriever.StorageService blockProcessor blockProcessorHandler @@ -47,6 +50,8 @@ type baseBlockTrack struct { blockBalancer blockBalancerHandler whitelistHandler process.WhiteListHandler feeHandler process.FeeHandler + enableEpochsHandler common.EnableEpochsHandler + epochChangeGracePeriodHandler common.EpochChangeGracePeriodHandler mutHeaders sync.RWMutex headers map[uint32]map[uint64][]*HeaderInfo @@ -103,6 +108,7 @@ func createBaseBlockTrack(arguments ArgBaseTracker) (*baseBlockTrack, error) { roundHandler: arguments.RoundHandler, shardCoordinator: arguments.ShardCoordinator, headersPool: arguments.PoolsHolder.Headers(), + proofsPool: arguments.PoolsHolder.Proofs(), store: arguments.Store, 
crossNotarizer: crossNotarizer, selfNotarizer: selfNotarizer, @@ -114,12 +120,46 @@ func createBaseBlockTrack(arguments ArgBaseTracker) (*baseBlockTrack, error) { maxNumHeadersToKeepPerShard: maxNumHeadersToKeepPerShard, whitelistHandler: arguments.WhitelistHandler, feeHandler: arguments.FeeHandler, + enableEpochsHandler: arguments.EnableEpochsHandler, + epochChangeGracePeriodHandler: arguments.EpochChangeGracePeriodHandler, } return bbt, nil } +func (bbt *baseBlockTrack) receivedProof(proof data.HeaderProofHandler) { + if check.IfNil(proof) { + return + } + + headerHash := proof.GetHeaderHash() + header, err := bbt.getHeaderForProof(proof) + if err != nil { + log.Debug("baseBlockTrack.receivedProof with missing header", "headerHash", headerHash) + return + } + log.Debug("received proof from network in block tracker", + "shard", proof.GetHeaderShardId(), + "epoch", proof.GetHeaderEpoch(), + "round", proof.GetHeaderRound(), + "nonce", proof.GetHeaderNonce(), + "hash", proof.GetHeaderHash(), + ) + + bbt.receivedHeader(header, headerHash) +} + +func (bbt *baseBlockTrack) getHeaderForProof(proof data.HeaderProofHandler) (data.HeaderHandler, error) { + return process.GetHeader(proof.GetHeaderHash(), bbt.headersPool, bbt.store, bbt.marshalizer, proof.GetHeaderShardId()) +} + func (bbt *baseBlockTrack) receivedHeader(headerHandler data.HeaderHandler, headerHash []byte) { + if common.IsProofsFlagEnabledForHeader(bbt.enableEpochsHandler, headerHandler) { + if !bbt.proofsPool.HasProof(headerHandler.GetShardID(), headerHash) { + return + } + } + if headerHandler.GetShardID() == core.MetachainShardId { bbt.receivedMetaBlock(headerHandler, headerHash) return @@ -784,12 +824,21 @@ func checkTrackerNilParameters(arguments ArgBaseTracker) error { if check.IfNil(arguments.PoolsHolder.Headers()) { return process.ErrNilHeadersDataPool } + if check.IfNil(arguments.PoolsHolder.Proofs()) { + return process.ErrNilProofsPool + } if check.IfNil(arguments.FeeHandler) { return process.ErrNilEconomicsData } if check.IfNil(arguments.WhitelistHandler) { return process.ErrNilWhiteListHandler } + if check.IfNil(arguments.EnableEpochsHandler) { + return process.ErrNilEnableEpochsHandler + } + if check.IfNil(arguments.EpochChangeGracePeriodHandler) { + return process.ErrNilEpochChangeGracePeriodHandler + } return nil } diff --git a/process/track/baseBlockTrack_test.go b/process/track/baseBlockTrack_test.go index 04927db27fc..4846ef5e8c5 100644 --- a/process/track/baseBlockTrack_test.go +++ b/process/track/baseBlockTrack_test.go @@ -10,6 +10,12 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common/graceperiod" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" processBlock "github.com/multiversx/mx-chain-go/process/block" @@ -22,10 +28,8 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" 
"github.com/multiversx/mx-chain-go/testscommon/economicsmocks" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) const maxGasLimitPerBlock = uint64(1500000000) @@ -106,8 +110,9 @@ func CreateShardTrackerMockArguments() track.ArgShardTracker { shardCoordinatorMock := mock.NewMultipleShardsCoordinatorMock() genesisBlocks := createGenesisBlocks(shardCoordinatorMock) argsHeaderValidator := processBlock.ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } headerValidator, _ := processBlock.NewHeaderValidator(argsHeaderValidator) whitelistHandler := &testscommon.WhiteListHandlerStub{} @@ -120,6 +125,13 @@ func CreateShardTrackerMockArguments() track.ArgShardTracker { }, } + epochChangeGracePeriod, _ := graceperiod.NewEpochChangeGracePeriod( + []config.EpochChangeGracePeriodByEpoch{ + { + EnableEpoch: 0, + GracePeriodInRounds: 1, + }}) + arguments := track.ArgShardTracker{ ArgBaseTracker: track.ArgBaseTracker{ Hasher: &hashingMocks.HasherMock{}, @@ -133,6 +145,13 @@ func CreateShardTrackerMockArguments() track.ArgShardTracker { PoolsHolder: dataRetrieverMock.NewPoolsHolderMock(), WhitelistHandler: whitelistHandler, FeeHandler: feeHandler, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + }, + EpochChangeGracePeriodHandler: epochChangeGracePeriod, + ProofsPool: &dataRetrieverMock.ProofsPoolMock{}, }, } @@ -144,8 +163,9 @@ func CreateMetaTrackerMockArguments() track.ArgMetaTracker { shardCoordinatorMock.CurrentShard = core.MetachainShardId genesisBlocks := createGenesisBlocks(shardCoordinatorMock) argsHeaderValidator := processBlock.ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } headerValidator, _ := processBlock.NewHeaderValidator(argsHeaderValidator) whitelistHandler := &testscommon.WhiteListHandlerStub{} @@ -158,6 +178,13 @@ func CreateMetaTrackerMockArguments() track.ArgMetaTracker { }, } + epochChangeGracePeriod, _ := graceperiod.NewEpochChangeGracePeriod( + []config.EpochChangeGracePeriodByEpoch{ + { + EnableEpoch: 0, + GracePeriodInRounds: 1, + }}) + arguments := track.ArgMetaTracker{ ArgBaseTracker: track.ArgBaseTracker{ Hasher: &hashingMocks.HasherMock{}, @@ -171,6 +198,13 @@ func CreateMetaTrackerMockArguments() track.ArgMetaTracker { PoolsHolder: dataRetrieverMock.NewPoolsHolderMock(), WhitelistHandler: whitelistHandler, FeeHandler: feeHandler, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + }, + EpochChangeGracePeriodHandler: epochChangeGracePeriod, + ProofsPool: &dataRetrieverMock.ProofsPoolMock{}, }, } @@ -181,8 +215,9 @@ func CreateBaseTrackerMockArguments() track.ArgBaseTracker { shardCoordinatorMock := 
mock.NewMultipleShardsCoordinatorMock() genesisBlocks := createGenesisBlocks(shardCoordinatorMock) argsHeaderValidator := processBlock.ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } headerValidator, _ := processBlock.NewHeaderValidator(argsHeaderValidator) feeHandler := &economicsmocks.EconomicsHandlerMock{ @@ -194,16 +229,26 @@ func CreateBaseTrackerMockArguments() track.ArgBaseTracker { }, } + epochChangeGracePeriod, _ := graceperiod.NewEpochChangeGracePeriod( + []config.EpochChangeGracePeriodByEpoch{ + { + EnableEpoch: 0, + GracePeriodInRounds: 1, + }}) + arguments := track.ArgBaseTracker{ - Hasher: &hashingMocks.HasherMock{}, - HeaderValidator: headerValidator, - Marshalizer: &mock.MarshalizerMock{}, - RequestHandler: &testscommon.RequestHandlerStub{}, - RoundHandler: &mock.RoundHandlerMock{}, - ShardCoordinator: shardCoordinatorMock, - Store: initStore(), - StartHeaders: genesisBlocks, - FeeHandler: feeHandler, + Hasher: &hashingMocks.HasherMock{}, + HeaderValidator: headerValidator, + Marshalizer: &mock.MarshalizerMock{}, + RequestHandler: &testscommon.RequestHandlerStub{}, + RoundHandler: &mock.RoundHandlerMock{}, + ShardCoordinator: shardCoordinatorMock, + Store: initStore(), + StartHeaders: genesisBlocks, + FeeHandler: feeHandler, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + EpochChangeGracePeriodHandler: epochChangeGracePeriod, + ProofsPool: &dataRetrieverMock.ProofsPoolMock{}, } return arguments @@ -397,6 +442,9 @@ func TestShardGetSelfHeaders_ShouldWork(t *testing.T) { }, } }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } sbt, _ := track.NewShardBlockTrack(shardArguments) @@ -434,6 +482,9 @@ func TestMetaGetSelfHeaders_ShouldWork(t *testing.T) { }, } }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } mbt, _ := track.NewMetaBlockTrack(metaArguments) @@ -1678,6 +1729,9 @@ func TestAddHeaderFromPool_ShouldWork(t *testing.T) { }, } }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, } sbt, _ := track.NewShardBlockTrack(shardArguments) @@ -2250,7 +2304,7 @@ func TestComputeLongestChain_ShouldWorkWithLongestChain(t *testing.T) { assert.Equal(t, longestChain+chains-1, uint64(len(headers))) } -//------- CheckBlockAgainstRoundHandler +// ------- CheckBlockAgainstRoundHandler func TestBaseBlockTrack_CheckBlockAgainstRoundHandlerNilHeaderShouldErr(t *testing.T) { t.Parallel() @@ -2299,7 +2353,7 @@ func TestBaseBlockTrack_CheckBlockAgainstRoundHandlerShouldWork(t *testing.T) { assert.Nil(t, err) } -//------- CheckBlockAgainstFinal +// ------- CheckBlockAgainstFinal func TestBaseBlockTrack_CheckBlockAgainstFinalNilHeaderShouldErr(t *testing.T) { t.Parallel() diff --git a/process/track/blockProcessor.go b/process/track/blockProcessor.go index e24ff02e35d..04eb126b077 100644 --- a/process/track/blockProcessor.go +++ b/process/track/blockProcessor.go @@ -4,9 +4,13 @@ import ( "sort" "github.com/multiversx/mx-chain-core-go/core" - "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/hashing" + 
"github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding" ) @@ -25,6 +29,13 @@ type blockProcessor struct { finalMetachainHeadersNotifier blockNotifierHandler roundHandler process.RoundHandler + enableEpochsHandler common.EnableEpochsHandler + proofsPool process.ProofsPool + marshaller marshal.Marshalizer + hasher hashing.Hasher + headersPool dataRetriever.HeadersPool + isImportDBMode bool + blockFinality uint64 } @@ -47,6 +58,12 @@ func NewBlockProcessor(arguments ArgBlockProcessor) (*blockProcessor, error) { selfNotarizedHeadersNotifier: arguments.SelfNotarizedHeadersNotifier, finalMetachainHeadersNotifier: arguments.FinalMetachainHeadersNotifier, roundHandler: arguments.RoundHandler, + enableEpochsHandler: arguments.EnableEpochsHandler, + proofsPool: arguments.ProofsPool, + headersPool: arguments.HeadersPool, + marshaller: arguments.Marshaller, + hasher: arguments.Hasher, + isImportDBMode: arguments.IsImportDBMode, } bp.blockFinality = process.BlockFinality @@ -154,7 +171,7 @@ func (bp *blockProcessor) doJobOnReceivedMetachainHeader() { } } - sortedHeaders, _ := bp.blockTracker.SortHeadersFromNonce(core.MetachainShardId, header.GetNonce()+1) + sortedHeaders, sortedHeadersHashes := bp.blockTracker.SortHeadersFromNonce(core.MetachainShardId, header.GetNonce()+1) if len(sortedHeaders) == 0 { return } @@ -162,7 +179,7 @@ func (bp *blockProcessor) doJobOnReceivedMetachainHeader() { finalMetachainHeaders := make([]data.HeaderHandler, 0) finalMetachainHeadersHashes := make([][]byte, 0) - err = bp.checkHeaderFinality(header, sortedHeaders, 0) + err = bp.checkHeaderFinality(header, sortedHeaders, sortedHeadersHashes, 0) if err == nil { finalMetachainHeaders = append(finalMetachainHeaders, header) finalMetachainHeadersHashes = append(finalMetachainHeadersHashes, headerHash) @@ -234,14 +251,15 @@ func (bp *blockProcessor) ComputeLongestChain(shardID uint32, header data.Header go bp.requestHeadersIfNeeded(header, sortedHeaders, headers) }() - sortedHeaders, sortedHeadersHashes = bp.blockTracker.SortHeadersFromNonce(shardID, header.GetNonce()+1) + startingNonce := header.GetNonce() + 1 + sortedHeaders, sortedHeadersHashes = bp.blockTracker.SortHeadersFromNonce(shardID, startingNonce) if len(sortedHeaders) == 0 { return headers, headersHashes } longestChainHeadersIndexes := make([]int, 0) headersIndexes := make([]int, 0) - bp.getNextHeader(&longestChainHeadersIndexes, headersIndexes, header, sortedHeaders, 0) + bp.getNextHeader(&longestChainHeadersIndexes, headersIndexes, header, sortedHeaders, sortedHeadersHashes, 0) for _, index := range longestChainHeadersIndexes { headers = append(headers, sortedHeaders[index]) @@ -256,6 +274,7 @@ func (bp *blockProcessor) getNextHeader( headersIndexes []int, prevHeader data.HeaderHandler, sortedHeaders []data.HeaderHandler, + sortedHeadersHashes [][]byte, index int, ) { defer func() { @@ -279,13 +298,13 @@ func (bp *blockProcessor) getNextHeader( continue } - err = bp.checkHeaderFinality(currHeader, sortedHeaders, i+1) + err = bp.checkHeaderFinality(currHeader, sortedHeaders, sortedHeadersHashes, i+1) if err != nil { continue } headersIndexes = append(headersIndexes, i) - bp.getNextHeader(longestChainHeadersIndexes, headersIndexes, currHeader, sortedHeaders, i+1) + 
bp.getNextHeader(longestChainHeadersIndexes, headersIndexes, currHeader, sortedHeaders, sortedHeadersHashes, i+1) headersIndexes = headersIndexes[:len(headersIndexes)-1] } } @@ -293,16 +312,28 @@ func (bp *blockProcessor) getNextHeader( func (bp *blockProcessor) checkHeaderFinality( header data.HeaderHandler, sortedHeaders []data.HeaderHandler, + sortedHeadersHashes [][]byte, index int, ) error { - if check.IfNil(header) { return process.ErrNilBlockHeader } + if bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, header.GetEpoch()) { + // the index in argument is for the next block after header + hashIndex := index + if index > 0 { + hashIndex = index - 1 + } + if bp.proofsPool.HasProof(header.GetShardID(), sortedHeadersHashes[hashIndex]) { + return nil + } + + return process.ErrHeaderNotFinal + } + prevHeader := header numFinalityAttestingHeaders := uint64(0) - for i := index; i < len(sortedHeaders); i++ { currHeader := sortedHeaders[i] if numFinalityAttestingHeaders >= bp.blockFinality || currHeader.GetNonce() > prevHeader.GetNonce()+1 { @@ -314,6 +345,16 @@ func (bp *blockProcessor) checkHeaderFinality( continue } + // if the currentHeader(the one that should confirm the finality of the prev) + // is the epoch start block of equivalent messages, we must check for its proof as well + if bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, currHeader.GetEpoch()) { + if bp.proofsPool.HasProof(currHeader.GetShardID(), sortedHeadersHashes[i]) { + return nil + } + + return process.ErrHeaderNotFinal + } + prevHeader = currHeader numFinalityAttestingHeaders += 1 } @@ -425,7 +466,16 @@ func (bp *blockProcessor) requestHeadersIfNothingNewIsReceived( "chronology round", bp.roundHandler.Index(), "highest round in received headers", highestRoundInReceivedHeaders) - bp.requestHeaders(latestValidHeader.GetShardID(), latestValidHeader.GetNonce()+1) + fromNonce := latestValidHeader.GetNonce() + shardID := latestValidHeader.GetShardID() + // force the trigger to be activated by removing the start of epoch block on Andromeda activation + header, headerHash, err := process.GetMetaHeaderFromPoolWithNonce(fromNonce, bp.headersPool) + isHeaderStartOfEpochForAndromedaActivation := err == nil && shardID == common.MetachainShardId && + common.IsEpochChangeBlockForFlagActivation(header, bp.enableEpochsHandler, common.AndromedaFlag) + if isHeaderStartOfEpochForAndromedaActivation { + bp.headersPool.RemoveHeaderByHash(headerHash) + } + bp.requestHeaders(shardID, fromNonce) } func (bp *blockProcessor) requestHeaders(shardID uint32, fromNonce uint64) { @@ -439,8 +489,10 @@ func (bp *blockProcessor) requestHeaders(shardID uint32, fromNonce uint64) { if shardID == core.MetachainShardId { bp.requestHandler.RequestMetaHeaderByNonce(nonce) + bp.requestHandler.RequestEquivalentProofByNonce(core.MetachainShardId, nonce) } else { bp.requestHandler.RequestShardHeaderByNonce(shardID, nonce) + bp.requestHandler.RequestEquivalentProofByNonce(shardID, nonce) } } } @@ -484,6 +536,21 @@ func checkBlockProcessorNilParameters(arguments ArgBlockProcessor) error { if check.IfNil(arguments.RoundHandler) { return ErrNilRoundHandler } + if check.IfNil(arguments.EnableEpochsHandler) { + return process.ErrNilEnableEpochsHandler + } + if check.IfNil(arguments.ProofsPool) { + return ErrNilProofsPool + } + if check.IfNil(arguments.Marshaller) { + return process.ErrNilMarshalizer + } + if check.IfNilReflect(arguments.Hasher) { + return process.ErrNilHasher + } + if check.IfNil(arguments.HeadersPool) { + return 
process.ErrNilHeadersDataPool + } return nil } diff --git a/process/track/blockProcessor_test.go b/process/track/blockProcessor_test.go index ad30bd35e06..4ab8bda8e15 100644 --- a/process/track/blockProcessor_test.go +++ b/process/track/blockProcessor_test.go @@ -7,25 +7,31 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/pool" "github.com/multiversx/mx-chain-core-go/data" dataBlock "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/process" processBlock "github.com/multiversx/mx-chain-go/process/block" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/track" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func CreateBlockProcessorMockArguments() track.ArgBlockProcessor { shardCoordinatorMock := mock.NewMultipleShardsCoordinatorMock() argsHeaderValidator := processBlock.ArgsHeaderValidator{ - Hasher: &hashingMocks.HasherMock{}, - Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Marshalizer: &mock.MarshalizerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, } headerValidator, _ := processBlock.NewHeaderValidator(argsHeaderValidator) @@ -56,7 +62,12 @@ func CreateBlockProcessorMockArguments() track.ArgBlockProcessor { return 1 }, }, - RoundHandler: &mock.RoundHandlerMock{}, + RoundHandler: &mock.RoundHandlerMock{}, + EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, + ProofsPool: &dataRetriever.ProofsPoolMock{}, + Marshaller: &testscommon.MarshallerStub{}, + Hasher: &hashingMocks.HasherMock{}, + HeadersPool: &pool.HeadersPoolStub{}, } return arguments @@ -172,6 +183,50 @@ func TestNewBlockProcessor_ShouldErrFinalMetachainHeadersNotifier(t *testing.T) assert.Nil(t, bp) } +func TestNewBlockProcessor_ShouldErrNilEnableEpochsHandler(t *testing.T) { + t.Parallel() + + blockProcessorArguments := CreateBlockProcessorMockArguments() + blockProcessorArguments.EnableEpochsHandler = nil + bp, err := track.NewBlockProcessor(blockProcessorArguments) + + assert.Equal(t, process.ErrNilEnableEpochsHandler, err) + assert.Nil(t, bp) +} + +func TestNewBlockProcessor_ShouldErrNilProofsPool(t *testing.T) { + t.Parallel() + + blockProcessorArguments := CreateBlockProcessorMockArguments() + blockProcessorArguments.ProofsPool = nil + bp, err := track.NewBlockProcessor(blockProcessorArguments) + + assert.Equal(t, track.ErrNilProofsPool, err) + assert.Nil(t, bp) +} + +func TestNewBlockProcessor_ShouldErrNilMarshaller(t *testing.T) { + t.Parallel() + + blockProcessorArguments := CreateBlockProcessorMockArguments() + blockProcessorArguments.Marshaller = nil + bp, err := track.NewBlockProcessor(blockProcessorArguments) + + assert.Equal(t, process.ErrNilMarshalizer, err) + assert.Nil(t, bp) 
+} + +func TestNewBlockProcessor_ShouldErrNilHasher(t *testing.T) { + t.Parallel() + + blockProcessorArguments := CreateBlockProcessorMockArguments() + blockProcessorArguments.Hasher = nil + bp, err := track.NewBlockProcessor(blockProcessorArguments) + + assert.Equal(t, process.ErrNilHasher, err) + assert.Nil(t, bp) +} + func TestNewBlockProcessor_ShouldErrNilRoundHandler(t *testing.T) { t.Parallel() @@ -553,7 +608,7 @@ func TestGetNextHeader_ShouldReturnEmptySliceWhenPrevHeaderIsNil(t *testing.T) { longestChainHeadersIndexes := make([]int, 0) headersIndexes := make([]int, 0) sortedHeaders := []data.HeaderHandler{&dataBlock.Header{Nonce: 1}} - bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, nil, sortedHeaders, 0) + bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, nil, sortedHeaders, [][]byte{}, 0) assert.Equal(t, 0, len(longestChainHeadersIndexes)) } @@ -568,7 +623,7 @@ func TestGetNextHeader_ShouldReturnEmptySliceWhenSortedHeadersHaveHigherNonces(t headersIndexes := make([]int, 0) prevHeader := &dataBlock.Header{} sortedHeaders := []data.HeaderHandler{&dataBlock.Header{Nonce: 2}} - bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, 0) + bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, [][]byte{}, 0) assert.Equal(t, 0, len(longestChainHeadersIndexes)) } @@ -583,7 +638,7 @@ func TestGetNextHeader_ShouldReturnEmptySliceWhenHeaderConstructionIsNotValid(t headersIndexes := make([]int, 0) prevHeader := &dataBlock.Header{} sortedHeaders := []data.HeaderHandler{&dataBlock.Header{Nonce: 1}} - bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, 0) + bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, [][]byte{}, 0) assert.Equal(t, 0, len(longestChainHeadersIndexes)) } @@ -614,7 +669,7 @@ func TestGetNextHeader_ShouldReturnEmptySliceWhenHeaderFinalityIsNotChecked(t *t } sortedHeaders := []data.HeaderHandler{header2} - bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, header1, sortedHeaders, 0) + bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, header1, sortedHeaders, [][]byte{}, 0) assert.Equal(t, 0, len(longestChainHeadersIndexes)) } @@ -653,7 +708,7 @@ func TestGetNextHeader_ShouldWork(t *testing.T) { } sortedHeaders := []data.HeaderHandler{header2, header3} - bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, header1, sortedHeaders, 0) + bp.GetNextHeader(&longestChainHeadersIndexes, headersIndexes, header1, sortedHeaders, [][]byte{}, 0) require.Equal(t, 1, len(longestChainHeadersIndexes)) assert.Equal(t, 0, longestChainHeadersIndexes[0]) @@ -666,7 +721,7 @@ func TestCheckHeaderFinality_ShouldErrNilBlockHeader(t *testing.T) { bp, _ := track.NewBlockProcessor(blockProcessorArguments) sortedHeaders := []data.HeaderHandler{&dataBlock.Header{Nonce: 1}} - err := bp.CheckHeaderFinality(nil, sortedHeaders, 0) + err := bp.CheckHeaderFinality(nil, sortedHeaders, [][]byte{}, 0) assert.Equal(t, process.ErrNilBlockHeader, err) } @@ -679,7 +734,7 @@ func TestCheckHeaderFinality_ShouldErrHeaderNotFinal(t *testing.T) { header := &dataBlock.Header{} sortedHeaders := []data.HeaderHandler{&dataBlock.Header{Nonce: 1}} - err := bp.CheckHeaderFinality(header, sortedHeaders, 0) + err := bp.CheckHeaderFinality(header, sortedHeaders, [][]byte{}, 0) assert.Equal(t, process.ErrHeaderNotFinal, err) } @@ -707,7 +762,7 @@ func TestCheckHeaderFinality_ShouldWork(t *testing.T) { } sortedHeaders := 
[]data.HeaderHandler{header2} - err := bp.CheckHeaderFinality(header1, sortedHeaders, 0) + err := bp.CheckHeaderFinality(header1, sortedHeaders, [][]byte{}, 0) assert.Nil(t, err) } diff --git a/process/track/errors.go b/process/track/errors.go index 2c9a3a5c297..220863ce86e 100644 --- a/process/track/errors.go +++ b/process/track/errors.go @@ -33,3 +33,6 @@ var ErrNilRoundHandler = errors.New("nil roundHandler") // ErrNilKeysHandler signals that a nil keys handler was provided var ErrNilKeysHandler = errors.New("nil keys handler") + +// ErrNilProofsPool signals that a nil proofs pool has been provided +var ErrNilProofsPool = errors.New("nil proofs pool") diff --git a/process/track/export_test.go b/process/track/export_test.go index 8a2752afb2c..8cbcccb2919 100644 --- a/process/track/export_test.go +++ b/process/track/export_test.go @@ -11,70 +11,86 @@ import ( // shardBlockTrack +// SetNumPendingMiniBlocks - func (sbt *shardBlockTrack) SetNumPendingMiniBlocks(shardID uint32, numPendingMiniBlocks uint32) { sbt.blockBalancer.SetNumPendingMiniBlocks(shardID, numPendingMiniBlocks) } +// GetNumPendingMiniBlocks - func (sbt *shardBlockTrack) GetNumPendingMiniBlocks(shardID uint32) uint32 { return sbt.blockBalancer.GetNumPendingMiniBlocks(shardID) } +// SetLastShardProcessedMetaNonce - func (sbt *shardBlockTrack) SetLastShardProcessedMetaNonce(shardID uint32, nonce uint64) { sbt.blockBalancer.SetLastShardProcessedMetaNonce(shardID, nonce) } +// GetLastShardProcessedMetaNonce - func (sbt *shardBlockTrack) GetLastShardProcessedMetaNonce(shardID uint32) uint64 { return sbt.blockBalancer.GetLastShardProcessedMetaNonce(shardID) } +// GetTrackedShardHeaderWithNonceAndHash - func (sbt *shardBlockTrack) GetTrackedShardHeaderWithNonceAndHash(shardID uint32, nonce uint64, hash []byte) (data.HeaderHandler, error) { return sbt.getTrackedShardHeaderWithNonceAndHash(shardID, nonce, hash) } // metaBlockTrack +// GetTrackedMetaBlockWithHash - func (mbt *metaBlockTrack) GetTrackedMetaBlockWithHash(hash []byte) (*block.MetaBlock, error) { return mbt.getTrackedMetaBlockWithHash(hash) } // baseBlockTrack +// ReceivedHeader - func (bbt *baseBlockTrack) ReceivedHeader(headerHandler data.HeaderHandler, headerHash []byte) { bbt.receivedHeader(headerHandler, headerHash) } +// CheckTrackerNilParameters - func CheckTrackerNilParameters(arguments ArgBaseTracker) error { return checkTrackerNilParameters(arguments) } +// InitNotarizedHeaders - func (bbt *baseBlockTrack) InitNotarizedHeaders(startHeaders map[uint32]data.HeaderHandler) error { return bbt.initNotarizedHeaders(startHeaders) } +// ReceivedShardHeader - func (bbt *baseBlockTrack) ReceivedShardHeader(headerHandler data.HeaderHandler, shardHeaderHash []byte) { bbt.receivedShardHeader(headerHandler, shardHeaderHash) } +// ReceivedMetaBlock - func (bbt *baseBlockTrack) ReceivedMetaBlock(headerHandler data.HeaderHandler, metaBlockHash []byte) { bbt.receivedMetaBlock(headerHandler, metaBlockHash) } +// GetMaxNumHeadersToKeepPerShard - func (bbt *baseBlockTrack) GetMaxNumHeadersToKeepPerShard() int { return bbt.maxNumHeadersToKeepPerShard } +// ShouldAddHeaderForCrossShard - func (bbt *baseBlockTrack) ShouldAddHeaderForCrossShard(headerHandler data.HeaderHandler) bool { return bbt.shouldAddHeaderForShard(headerHandler, bbt.crossNotarizer, headerHandler.GetShardID()) } +// ShouldAddHeaderForSelfShard - func (bbt *baseBlockTrack) ShouldAddHeaderForSelfShard(headerHandler data.HeaderHandler) bool { return bbt.shouldAddHeaderForShard(headerHandler, bbt.selfNotarizer, 
core.MetachainShardId) } +// AddHeader - func (bbt *baseBlockTrack) AddHeader(header data.HeaderHandler, hash []byte) bool { return bbt.addHeader(header, hash) } +// AppendTrackedHeader - func (bbt *baseBlockTrack) AppendTrackedHeader(headerHandler data.HeaderHandler) { bbt.mutHeaders.Lock() headersForShard, ok := bbt.headers[headerHandler.GetShardID()] @@ -87,48 +103,59 @@ func (bbt *baseBlockTrack) AppendTrackedHeader(headerHandler data.HeaderHandler) bbt.mutHeaders.Unlock() } +// CleanupTrackedHeadersBehindNonce - func (bbt *baseBlockTrack) CleanupTrackedHeadersBehindNonce(shardID uint32, nonce uint64) { bbt.cleanupTrackedHeadersBehindNonce(shardID, nonce) } +// DisplayTrackedHeadersForShard - func (bbt *baseBlockTrack) DisplayTrackedHeadersForShard(shardID uint32, message string) { bbt.displayTrackedHeadersForShard(shardID, message) } +// SetRoundHandler - func (bbt *baseBlockTrack) SetRoundHandler(roundHandler process.RoundHandler) { bbt.roundHandler = roundHandler } +// SetCrossNotarizer - func (bbt *baseBlockTrack) SetCrossNotarizer(notarizer blockNotarizerHandler) { bbt.crossNotarizer = notarizer } +// SetSelfNotarizer - func (bbt *baseBlockTrack) SetSelfNotarizer(notarizer blockNotarizerHandler) { bbt.selfNotarizer = notarizer } +// SetShardCoordinator - func (bbt *baseBlockTrack) SetShardCoordinator(coordinator sharding.Coordinator) { bbt.shardCoordinator = coordinator } +// NewBaseBlockTrack - func NewBaseBlockTrack() *baseBlockTrack { return &baseBlockTrack{} } +// DoWhitelistWithMetaBlockIfNeeded - func (bbt *baseBlockTrack) DoWhitelistWithMetaBlockIfNeeded(metaBlock *block.MetaBlock) { bbt.doWhitelistWithMetaBlockIfNeeded(metaBlock) } +// DoWhitelistWithShardHeaderIfNeeded - func (bbt *baseBlockTrack) DoWhitelistWithShardHeaderIfNeeded(shardHeader *block.Header) { bbt.doWhitelistWithShardHeaderIfNeeded(shardHeader) } +// IsHeaderOutOfRange - func (bbt *baseBlockTrack) IsHeaderOutOfRange(headerHandler data.HeaderHandler) bool { return bbt.isHeaderOutOfRange(headerHandler) } // blockNotifier +// GetNotarizedHeadersHandlers - func (bn *blockNotifier) GetNotarizedHeadersHandlers() []func(shardID uint32, headers []data.HeaderHandler, headersHashes [][]byte) { bn.mutNotarizedHeadersHandlers.RLock() notarizedHeadersHandlers := bn.notarizedHeadersHandlers @@ -139,12 +166,14 @@ func (bn *blockNotifier) GetNotarizedHeadersHandlers() []func(shardID uint32, he // blockNotarizer +// AppendNotarizedHeader - func (bn *blockNotarizer) AppendNotarizedHeader(headerHandler data.HeaderHandler) { bn.mutNotarizedHeaders.Lock() bn.notarizedHeaders[headerHandler.GetShardID()] = append(bn.notarizedHeaders[headerHandler.GetShardID()], &HeaderInfo{Header: headerHandler}) bn.mutNotarizedHeaders.Unlock() } +// GetNotarizedHeaders - func (bn *blockNotarizer) GetNotarizedHeaders() map[uint32][]*HeaderInfo { bn.mutNotarizedHeaders.RLock() notarizedHeaders := bn.notarizedHeaders @@ -153,6 +182,7 @@ func (bn *blockNotarizer) GetNotarizedHeaders() map[uint32][]*HeaderInfo { return notarizedHeaders } +// GetNotarizedHeaderWithIndex - func (bn *blockNotarizer) GetNotarizedHeaderWithIndex(shardID uint32, index int) data.HeaderHandler { bn.mutNotarizedHeaders.RLock() notarizedHeader := bn.notarizedHeaders[shardID][index].Header @@ -161,70 +191,98 @@ func (bn *blockNotarizer) GetNotarizedHeaderWithIndex(shardID uint32, index int) return notarizedHeader } +// LastNotarizedHeaderInfo - func (bn *blockNotarizer) LastNotarizedHeaderInfo(shardID uint32) *HeaderInfo { return bn.lastNotarizedHeaderInfo(shardID) } // 
blockProcessor +// DoJobOnReceivedHeader - func (bp *blockProcessor) DoJobOnReceivedHeader(shardID uint32) { bp.doJobOnReceivedHeader(shardID) } +// DoJobOnReceivedCrossNotarizedHeader - func (bp *blockProcessor) DoJobOnReceivedCrossNotarizedHeader(shardID uint32) { bp.doJobOnReceivedCrossNotarizedHeader(shardID) } +// ComputeLongestChainFromLastCrossNotarized - func (bp *blockProcessor) ComputeLongestChainFromLastCrossNotarized(shardID uint32) (data.HeaderHandler, []byte, []data.HeaderHandler, [][]byte) { return bp.computeLongestChainFromLastCrossNotarized(shardID) } +// ComputeSelfNotarizedHeaders - func (bp *blockProcessor) ComputeSelfNotarizedHeaders(headers []data.HeaderHandler) ([]data.HeaderHandler, [][]byte) { return bp.computeSelfNotarizedHeaders(headers) } -func (bp *blockProcessor) GetNextHeader(longestChainHeadersIndexes *[]int, headersIndexes []int, prevHeader data.HeaderHandler, sortedHeaders []data.HeaderHandler, index int) { - bp.getNextHeader(longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, index) +// GetNextHeader - +func (bp *blockProcessor) GetNextHeader( + longestChainHeadersIndexes *[]int, + headersIndexes []int, + prevHeader data.HeaderHandler, + sortedHeaders []data.HeaderHandler, + sortedHashes [][]byte, + index int, +) { + bp.getNextHeader(longestChainHeadersIndexes, headersIndexes, prevHeader, sortedHeaders, sortedHashes, index) } -func (bp *blockProcessor) CheckHeaderFinality(header data.HeaderHandler, sortedHeaders []data.HeaderHandler, index int) error { - return bp.checkHeaderFinality(header, sortedHeaders, index) +// CheckHeaderFinality - +func (bp *blockProcessor) CheckHeaderFinality( + header data.HeaderHandler, + sortedHeaders []data.HeaderHandler, + sortedHashes [][]byte, + index int, +) error { + return bp.checkHeaderFinality(header, sortedHeaders, sortedHashes, index) } +// RequestHeadersIfNeeded - func (bp *blockProcessor) RequestHeadersIfNeeded(lastNotarizedHeader data.HeaderHandler, sortedHeaders []data.HeaderHandler, longestChainHeaders []data.HeaderHandler) { bp.requestHeadersIfNeeded(lastNotarizedHeader, sortedHeaders, longestChainHeaders) } +// GetLatestValidHeader - func (bp *blockProcessor) GetLatestValidHeader(lastNotarizedHeader data.HeaderHandler, longestChainHeaders []data.HeaderHandler) data.HeaderHandler { return bp.getLatestValidHeader(lastNotarizedHeader, longestChainHeaders) } +// GetHighestRoundInReceivedHeaders - func (bp *blockProcessor) GetHighestRoundInReceivedHeaders(latestValidHeader data.HeaderHandler, sortedReceivedHeaders []data.HeaderHandler) uint64 { return bp.getHighestRoundInReceivedHeaders(latestValidHeader, sortedReceivedHeaders) } +// RequestHeadersIfNothingNewIsReceived - func (bp *blockProcessor) RequestHeadersIfNothingNewIsReceived(lastNotarizedHeaderNonce uint64, latestValidHeader data.HeaderHandler, highestRoundInReceivedHeaders uint64) { bp.requestHeadersIfNothingNewIsReceived(lastNotarizedHeaderNonce, latestValidHeader, highestRoundInReceivedHeaders) } +// RequestHeaders - func (bp *blockProcessor) RequestHeaders(shardID uint32, fromNonce uint64) { bp.requestHeaders(shardID, fromNonce) } +// ShouldProcessReceivedHeader - func (bp *blockProcessor) ShouldProcessReceivedHeader(headerHandler data.HeaderHandler) bool { return bp.shouldProcessReceivedHeader(headerHandler) } // miniBlockTrack +// ReceivedMiniBlock - func (mbt *miniBlockTrack) ReceivedMiniBlock(key []byte, value interface{}) { mbt.receivedMiniBlock(key, value) } +// GetTransactionPool - func (mbt *miniBlockTrack) 
GetTransactionPool(mbType block.Type) dataRetriever.ShardedDataCacherNotifier { return mbt.getTransactionPool(mbType) } +// SetBlockTransactionsPool - func (mbt *miniBlockTrack) SetBlockTransactionsPool(blockTransactionsPool dataRetriever.ShardedDataCacherNotifier) { mbt.blockTransactionsPool = blockTransactionsPool } diff --git a/process/track/metaBlockTrack.go b/process/track/metaBlockTrack.go index 26e13d58e1c..b15ed33b383 100644 --- a/process/track/metaBlockTrack.go +++ b/process/track/metaBlockTrack.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/process" ) @@ -46,6 +47,12 @@ func NewMetaBlockTrack(arguments ArgMetaTracker) (*metaBlockTrack, error) { SelfNotarizedHeadersNotifier: bbt.selfNotarizedHeadersNotifier, FinalMetachainHeadersNotifier: bbt.finalMetachainHeadersNotifier, RoundHandler: arguments.RoundHandler, + EnableEpochsHandler: arguments.EnableEpochsHandler, + ProofsPool: arguments.ProofsPool, + Marshaller: arguments.Marshalizer, + Hasher: arguments.Hasher, + HeadersPool: arguments.PoolsHolder.Headers(), + IsImportDBMode: arguments.IsImportDBMode, } blockProcessorObject, err := NewBlockProcessor(argBlockProcessor) @@ -56,6 +63,7 @@ func NewMetaBlockTrack(arguments ArgMetaTracker) (*metaBlockTrack, error) { mbt.blockProcessor = blockProcessorObject mbt.headers = make(map[uint32]map[uint64][]*HeaderInfo) mbt.headersPool.RegisterHandler(mbt.receivedHeader) + mbt.proofsPool.RegisterHandler(mbt.receivedProof) mbt.headersPool.Clear() return &mbt, nil @@ -141,7 +149,13 @@ func (mbt *metaBlockTrack) removeInvalidShardHeadersDueToEpochChange( for _, headerInfo := range headersInfo { round := headerInfo.Header.GetRound() epoch := headerInfo.Header.GetEpoch() - isInvalidHeader := round > metaRoundAttestingEpoch+process.EpochChangeGracePeriod && epoch < metaNewEpoch + gracePeriod, err := mbt.epochChangeGracePeriodHandler.GetGracePeriodForEpoch(metaNewEpoch) + if err != nil { + log.Warn("get grace period for epoch", "error", err.Error()) + continue + } + + isInvalidHeader := round > metaRoundAttestingEpoch+uint64(gracePeriod) && epoch < metaNewEpoch if !isInvalidHeader { newHeadersInfo = append(newHeadersInfo, headerInfo) } diff --git a/process/track/miniBlockTrack_test.go b/process/track/miniBlockTrack_test.go index 123c3813052..6a72d7ad9d0 100644 --- a/process/track/miniBlockTrack_test.go +++ b/process/track/miniBlockTrack_test.go @@ -4,14 +4,16 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/track" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" - "github.com/stretchr/testify/assert" ) func TestNewMiniBlockTrack_NilDataPoolHolderErr(t *testing.T) { @@ -256,7 +258,7 @@ func TestGetTransactionPool_ShouldWork(t *testing.T) { return 
unsignedTransactionsPool }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, } mbt, _ := track.NewMiniBlockTrack(dataPool, mock.NewMultipleShardsCoordinatorMock(), &testscommon.WhiteListHandlerStub{}) @@ -286,7 +288,7 @@ func createDataPool() dataRetriever.PoolsHolder { return testscommon.NewShardedDataStub() }, MiniBlocksCalled: func() storage.Cacher { - return testscommon.NewCacherStub() + return cache.NewCacherStub() }, } } diff --git a/process/track/shardBlockTrack.go b/process/track/shardBlockTrack.go index 327282725bc..7a5617e7896 100644 --- a/process/track/shardBlockTrack.go +++ b/process/track/shardBlockTrack.go @@ -6,6 +6,7 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/process" ) @@ -46,6 +47,12 @@ func NewShardBlockTrack(arguments ArgShardTracker) (*shardBlockTrack, error) { SelfNotarizedHeadersNotifier: bbt.selfNotarizedHeadersNotifier, FinalMetachainHeadersNotifier: bbt.finalMetachainHeadersNotifier, RoundHandler: arguments.RoundHandler, + EnableEpochsHandler: arguments.EnableEpochsHandler, + ProofsPool: arguments.ProofsPool, + Marshaller: arguments.Marshalizer, + Hasher: arguments.Hasher, + HeadersPool: arguments.PoolsHolder.Headers(), + IsImportDBMode: arguments.IsImportDBMode, } blockProcessorObject, err := NewBlockProcessor(argBlockProcessor) @@ -56,6 +63,7 @@ func NewShardBlockTrack(arguments ArgShardTracker) (*shardBlockTrack, error) { sbt.blockProcessor = blockProcessorObject sbt.headers = make(map[uint32]map[uint64][]*HeaderInfo) sbt.headersPool.RegisterHandler(sbt.receivedHeader) + sbt.proofsPool.RegisterHandler(sbt.receivedProof) sbt.headersPool.Clear() return &sbt, nil diff --git a/process/transaction/interceptedTransaction_test.go b/process/transaction/interceptedTransaction_test.go index 4c888f60863..a526d91feb5 100644 --- a/process/transaction/interceptedTransaction_test.go +++ b/process/transaction/interceptedTransaction_test.go @@ -15,6 +15,10 @@ import ( "github.com/multiversx/mx-chain-core-go/data" dataTransaction "github.com/multiversx/mx-chain-core-go/data/transaction" "github.com/multiversx/mx-chain-crypto-go" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/interceptors" @@ -22,13 +26,11 @@ import ( "github.com/multiversx/mx-chain-go/process/smartContract" "github.com/multiversx/mx-chain-go/process/transaction" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/economicsmocks" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/assert" 
- "github.com/stretchr/testify/require" ) var errSingleSignKeyGenMock = errors.New("errSingleSignKeyGenMock") @@ -1365,7 +1367,7 @@ func TestInterceptedTransaction_CheckValiditySecondTimeDoesNotVerifySig(t *testi return shardCoordinator.CurrentShard } - cache := testscommon.NewCacherMock() + cache := cache.NewCacherMock() whiteListerVerifiedTxs, err := interceptors.NewWhiteListDataVerifier(cache) require.Nil(t, err) @@ -1645,7 +1647,7 @@ func TestRelayTransaction_NotAddedToWhitelistUntilIntegrityChecked(t *testing.T) t.Parallel() marshalizer := &mock.MarshalizerMock{} - whiteListHandler, _ := interceptors.NewWhiteListDataVerifier(testscommon.NewCacherMock()) + whiteListHandler, _ := interceptors.NewWhiteListDataVerifier(cache.NewCacherMock()) userTx := &dataTransaction.Transaction{ SndAddr: recvAddress, diff --git a/process/transactionEvaluator/simulationAccountsDB_test.go b/process/transactionEvaluator/simulationAccountsDB_test.go index 13655ba315f..16c9e7effdd 100644 --- a/process/transactionEvaluator/simulationAccountsDB_test.go +++ b/process/transactionEvaluator/simulationAccountsDB_test.go @@ -37,36 +37,36 @@ func TestReadOnlyAccountsDB_WriteOperationsShouldNotCalled(t *testing.T) { failErrMsg := "this function should have not be called" accDb := &stateMock.AccountsStub{ SaveAccountCalled: func(account vmcommon.AccountHandler) error { - t.Errorf(failErrMsg) + t.Errorf("%s", failErrMsg) return nil }, RemoveAccountCalled: func(_ []byte) error { - t.Errorf(failErrMsg) + t.Errorf("%s", failErrMsg) return nil }, CommitCalled: func() ([]byte, error) { - t.Errorf(failErrMsg) + t.Errorf("%s", failErrMsg) return nil, nil }, RevertToSnapshotCalled: func(_ int) error { - t.Errorf(failErrMsg) + t.Errorf("%s", failErrMsg) return nil }, RecreateTrieCalled: func(_ common.RootHashHolder) error { - t.Errorf(failErrMsg) + t.Errorf("%s", failErrMsg) return nil }, PruneTrieCalled: func(_ []byte, _ state.TriePruningIdentifier, _ state.PruningHandler) { - t.Errorf(failErrMsg) + t.Errorf("%s", failErrMsg) }, CancelPruneCalled: func(_ []byte, _ state.TriePruningIdentifier) { - t.Errorf(failErrMsg) + t.Errorf("%s", failErrMsg) }, SnapshotStateCalled: func(_ []byte, _ uint32) { - t.Errorf(failErrMsg) + t.Errorf("%s", failErrMsg) }, RecreateAllTriesCalled: func(_ []byte) (map[string]common.Trie, error) { - t.Errorf(failErrMsg) + t.Errorf("%s", failErrMsg) return nil, nil }, } diff --git a/process/txstatus/txStatusComputer.go b/process/txstatus/txStatusComputer.go index 74b98ad8ffd..1eec437776e 100644 --- a/process/txstatus/txStatusComputer.go +++ b/process/txstatus/txStatusComputer.go @@ -120,7 +120,7 @@ func (sc *statusComputer) SetStatusIfIsRewardReverted( if selfShardID == core.MetachainShardId { storerUnit = dataRetriever.MetaHdrNonceHashDataUnit } else { - storerUnit = dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(selfShardID) + storerUnit = dataRetriever.GetHdrNonceHashDataUnit(selfShardID) } nonceToByteSlice := sc.uint64ByteSliceConverter.ToByteSlice(headerNonce) diff --git a/process/unsigned/interceptedUnsignedTransaction_test.go b/process/unsigned/interceptedUnsignedTransaction_test.go index b0c00e4982e..102b76c0975 100644 --- a/process/unsigned/interceptedUnsignedTransaction_test.go +++ b/process/unsigned/interceptedUnsignedTransaction_test.go @@ -11,13 +11,14 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" 
"github.com/multiversx/mx-chain-core-go/data/smartContractResult" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/process/mock" "github.com/multiversx/mx-chain-go/process/unsigned" "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" - logger "github.com/multiversx/mx-chain-logger-go" - "github.com/stretchr/testify/assert" ) var senderShard = uint32(2) @@ -170,7 +171,7 @@ func TestNewInterceptedUnsignedTransaction_ShouldWork(t *testing.T) { assert.Nil(t, err) } -// ------- CheckValidity +// ------- Verify func TestInterceptedUnsignedTransaction_CheckValidityNilTxHashShouldErr(t *testing.T) { t.Parallel() diff --git a/scripts/generators/mockGenerator.sh b/scripts/generators/mockGenerator.sh index 71e120c4c52..e48163b158b 100755 --- a/scripts/generators/mockGenerator.sh +++ b/scripts/generators/mockGenerator.sh @@ -240,7 +240,7 @@ extractDefaultReturn() { writeWithReturn() { { echo "return mock.$mockField($stringParamNames)"; - echo "}"; + echo -e "}\n"; } >> "$mockPath" # compute default values to return when mock member is not provided, separated by comma diff --git a/scripts/testnet/include/config.sh b/scripts/testnet/include/config.sh index 25f836a84b7..16ed10fba39 100644 --- a/scripts/testnet/include/config.sh +++ b/scripts/testnet/include/config.sh @@ -135,10 +135,25 @@ updateNodeConfig() { updateConfigsForStakingV4 + # Update chain parameters + updateChainParameters config_observer.toml + updateChainParameters config_validator.toml + echo "Updated configuration for Nodes." 
popd } +updateChainParameters() { + tomlFile=$1 + + sed -i "s,ShardConsensusGroupSize\([^,]*\),ShardConsensusGroupSize = $SHARD_CONSENSUS_SIZE," $tomlFile + sed -i "s,ShardMinNumNodes\([^,]*\),ShardMinNumNodes = $SHARD_CONSENSUS_SIZE," $tomlFile + sed -i "s,MetachainConsensusGroupSize\([^,]*\),MetachainConsensusGroupSize = $META_CONSENSUS_SIZE," $tomlFile + sed -i "s,MetachainMinNumNodes\([^,]*\),MetachainMinNumNodes = $META_CONSENSUS_SIZE," $tomlFile + sed -i "s,RoundDuration\([^,]*\),RoundDuration = $ROUND_DURATION_IN_MS," $tomlFile + sed -i "s,Hysteresis\([^,]*\),Hysteresis = $HYSTERESIS," $tomlFile +} + updateConfigsForStakingV4() { config=$(cat enableEpochs.toml) diff --git a/scripts/testnet/include/validators.sh b/scripts/testnet/include/validators.sh index b19bad12525..0999a3e4af3 100644 --- a/scripts/testnet/include/validators.sh +++ b/scripts/testnet/include/validators.sh @@ -85,9 +85,9 @@ startSingleValidator() { local startCommand="" if [ "$NODE_WATCHER" -eq 1 ]; then setWorkdirForNextCommands "$TESTNETDIR/node_working_dirs/$DIR_NAME$VALIDATOR_INDEX" - startCommand="$(assembleCommand_startValidatorNodeWithWatcher $VALIDATOR_INDEX $DIR_NAME)" + startCommand="$(assembleCommand_startValidatorNodeWithWatcher $VALIDATOR_INDEX $DIR_NAME $SHARD)" else - startCommand="$(assembleCommand_startValidatorNode $VALIDATOR_INDEX $DIR_NAME)" + startCommand="$(assembleCommand_startValidatorNode $VALIDATOR_INDEX $DIR_NAME $SHARD)" fi runCommandInTerminal "$startCommand" } @@ -129,12 +129,13 @@ stopSingleValidator() { assembleCommand_startValidatorNodeWithWatcher() { VALIDATOR_INDEX=$1 DIR_NAME=$2 + SHARD=$3 (( PORT=$PORT_ORIGIN_VALIDATOR + $VALIDATOR_INDEX )) WORKING_DIR=$TESTNETDIR/node_working_dirs/$DIR_NAME$VALIDATOR_INDEX local source_command="source $MULTIVERSXTESTNETSCRIPTSDIR/include/watcher.sh" local watcher_command="node-start-with-watcher $VALIDATOR_INDEX $PORT &" - local node_command=$(assembleCommand_startValidatorNode $VALIDATOR_INDEX $DIR_NAME) + local node_command=$(assembleCommand_startValidatorNode $VALIDATOR_INDEX $DIR_NAME $SHARD) mkdir -p $WORKING_DIR echo "$node_command" > $WORKING_DIR/node-command echo "$PORT" > $WORKING_DIR/node-port @@ -145,6 +146,7 @@ assembleCommand_startValidatorNodeWithWatcher() { assembleCommand_startValidatorNode() { VALIDATOR_INDEX=$1 DIR_NAME=$2 + SHARD=$3 (( PORT=$PORT_ORIGIN_VALIDATOR + $VALIDATOR_INDEX )) (( RESTAPIPORT=$PORT_ORIGIN_VALIDATOR_REST + $VALIDATOR_INDEX )) (( KEY_INDEX=$VALIDATOR_INDEX )) @@ -155,6 +157,10 @@ assembleCommand_startValidatorNode() { -sk-index $KEY_INDEX \ -working-directory $WORKING_DIR -config ./config/config_validator.toml" + if [[ $MULTI_KEY_NODES -eq 1 ]]; then + node_command="$node_command --destination-shard-as-observer $SHARD" + fi + if [ -n "$NODE_NICENESS" ] then node_command="nice -n $NODE_NICENESS $node_command" diff --git a/sharding/chainParametersHolder.go b/sharding/chainParametersHolder.go new file mode 100644 index 00000000000..341460d2dbd --- /dev/null +++ b/sharding/chainParametersHolder.go @@ -0,0 +1,187 @@ +package sharding + +import ( + "fmt" + "sort" + "strings" + "sync" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" +) + +type chainParametersHolder struct { + currentChainParameters config.ChainParametersByEpochConfig + chainParameters []config.ChainParametersByEpochConfig + 
chainParametersNotifier ChainParametersNotifierHandler + mutOperations sync.RWMutex +} + +// ArgsChainParametersHolder holds the arguments needed for creating a new chainParametersHolder +type ArgsChainParametersHolder struct { + EpochStartEventNotifier EpochStartEventNotifier + ChainParameters []config.ChainParametersByEpochConfig + ChainParametersNotifier ChainParametersNotifierHandler +} + +// NewChainParametersHolder returns a new instance of chainParametersHolder +func NewChainParametersHolder(args ArgsChainParametersHolder) (*chainParametersHolder, error) { + err := validateArgs(args) + if err != nil { + return nil, err + } + + chainParameters := args.ChainParameters + // sort the config values in descending order + sort.SliceStable(chainParameters, func(i, j int) bool { + return chainParameters[i].EnableEpoch > chainParameters[j].EnableEpoch + }) + + earliestChainParams := chainParameters[len(chainParameters)-1] + if earliestChainParams.EnableEpoch != 0 { + return nil, ErrMissingConfigurationForEpochZero + } + + paramsHolder := &chainParametersHolder{ + currentChainParameters: earliestChainParams, // will be updated on the epoch notifier handlers + chainParameters: args.ChainParameters, + chainParametersNotifier: args.ChainParametersNotifier, + } + args.ChainParametersNotifier.UpdateCurrentChainParameters(earliestChainParams) + args.EpochStartEventNotifier.RegisterHandler(paramsHolder) + + logInitialConfiguration(args.ChainParameters) + + return paramsHolder, nil +} + +func logInitialConfiguration(chainParameters []config.ChainParametersByEpochConfig) { + logMessage := "initialized chainParametersHolder with the values:\n" + logLines := make([]string, 0, len(chainParameters)) + for _, params := range chainParameters { + logLines = append(logLines, fmt.Sprintf("\tenable epoch=%d, round duration=%d, hysteresis=%.2f, shard consensus group size=%d, shard min nodes=%d, meta consensus group size=%d, meta min nodes=%d, adaptivity=%v", + params.EnableEpoch, params.RoundDuration, params.Hysteresis, params.ShardConsensusGroupSize, params.ShardMinNumNodes, params.MetachainConsensusGroupSize, params.MetachainMinNumNodes, params.Adaptivity)) + } + + logMessage += strings.Join(logLines, "\n") + log.Debug(logMessage) +} + +func validateArgs(args ArgsChainParametersHolder) error { + if check.IfNil(args.EpochStartEventNotifier) { + return ErrNilEpochStartEventNotifier + } + if len(args.ChainParameters) == 0 { + return ErrMissingChainParameters + } + if check.IfNil(args.ChainParametersNotifier) { + return ErrNilChainParametersNotifier + } + return validateChainParameters(args.ChainParameters) +} + +func validateChainParameters(chainParametersConfig []config.ChainParametersByEpochConfig) error { + for idx, chainParameters := range chainParametersConfig { + if chainParameters.ShardConsensusGroupSize < 1 { + return fmt.Errorf("%w for chain parameters with index %d", ErrNegativeOrZeroConsensusGroupSize, idx) + } + if chainParameters.ShardMinNumNodes < chainParameters.ShardConsensusGroupSize { + return fmt.Errorf("%w for chain parameters with index %d", ErrMinNodesPerShardSmallerThanConsensusSize, idx) + } + if chainParameters.MetachainConsensusGroupSize < 1 { + return fmt.Errorf("%w for chain parameters with index %d", ErrNegativeOrZeroConsensusGroupSize, idx) + } + if chainParameters.MetachainMinNumNodes < chainParameters.MetachainConsensusGroupSize { + return fmt.Errorf("%w for chain parameters with index %d", ErrMinNodesPerShardSmallerThanConsensusSize, idx) + } + } + + return nil +} + +// 
EpochStartAction is called when a new epoch is confirmed +func (c *chainParametersHolder) EpochStartAction(header data.HeaderHandler) { + c.handleEpochChange(header.GetEpoch()) +} + +// EpochStartPrepare is called when a new epoch is observed, but not yet confirmed. No action is required on this component +func (c *chainParametersHolder) EpochStartPrepare(_ data.HeaderHandler, _ data.BodyHandler) { +} + +// NotifyOrder returns the notification order for a start of epoch event +func (c *chainParametersHolder) NotifyOrder() uint32 { + return common.ChainParametersOrder +} + +func (c *chainParametersHolder) handleEpochChange(epoch uint32) { + c.mutOperations.Lock() + defer c.mutOperations.Unlock() + + matchingVersionForNewEpoch, err := getMatchingChainParametersUnprotected(epoch, c.chainParameters) + if err != nil { + log.Error("chainParametersHolder.handleEpochChange: cannot get matching chain parameters", "epoch", epoch, "error", err) + return + } + if matchingVersionForNewEpoch.EnableEpoch == c.currentChainParameters.EnableEpoch { + return + } + + c.currentChainParameters = matchingVersionForNewEpoch + log.Debug("updated chainParametersHolder current chain parameters", + "round duration", matchingVersionForNewEpoch.RoundDuration, + "shard consensus group size", matchingVersionForNewEpoch.ShardConsensusGroupSize, + "shard min num nodes", matchingVersionForNewEpoch.ShardMinNumNodes, + "metachain consensus group size", matchingVersionForNewEpoch.MetachainConsensusGroupSize, + "metachain min num nodes", matchingVersionForNewEpoch.MetachainMinNumNodes, + "hysteresis", matchingVersionForNewEpoch.Hysteresis, + "adaptivity", matchingVersionForNewEpoch.Adaptivity, + ) + c.chainParametersNotifier.UpdateCurrentChainParameters(matchingVersionForNewEpoch) +} + +// CurrentChainParameters will return the chain parameters that are active at the moment of calling +func (c *chainParametersHolder) CurrentChainParameters() config.ChainParametersByEpochConfig { + c.mutOperations.RLock() + defer c.mutOperations.RUnlock() + + return c.currentChainParameters +} + +// AllChainParameters will return the entire slice of chain parameters configuration +func (c *chainParametersHolder) AllChainParameters() []config.ChainParametersByEpochConfig { + c.mutOperations.RLock() + defer c.mutOperations.RUnlock() + + chainParametersCopy := make([]config.ChainParametersByEpochConfig, len(c.chainParameters)) + copy(chainParametersCopy, c.chainParameters) + + return chainParametersCopy +} + +// ChainParametersForEpoch will return the corresponding chain parameters for the provided epoch +func (c *chainParametersHolder) ChainParametersForEpoch(epoch uint32) (config.ChainParametersByEpochConfig, error) { + c.mutOperations.RLock() + defer c.mutOperations.RUnlock() + + return getMatchingChainParametersUnprotected(epoch, c.chainParameters) +} + +func getMatchingChainParametersUnprotected(epoch uint32, configValues []config.ChainParametersByEpochConfig) (config.ChainParametersByEpochConfig, error) { + for _, chainParams := range configValues { + if chainParams.EnableEpoch <= epoch { + return chainParams, nil + } + } + + // should never reach this code, as the config values are checked in the constructor + return config.ChainParametersByEpochConfig{}, ErrNoMatchingConfigurationFound +} + +// IsInterfaceNil returns true if there is no value under the interface +func (c *chainParametersHolder) IsInterfaceNil() bool { + return c == nil +} diff --git 
a/sharding/chainParametersHolder_test.go b/sharding/chainParametersHolder_test.go new file mode 100644 index 00000000000..7ec5876cc7d --- /dev/null +++ b/sharding/chainParametersHolder_test.go @@ -0,0 +1,386 @@ +package sharding + +import ( + "fmt" + "sync" + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/data/block" + + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/testscommon/commonmocks" + mock "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" + + "github.com/stretchr/testify/require" +) + +func TestNewChainParametersHolder(t *testing.T) { + t.Parallel() + + getDummyArgs := func() ArgsChainParametersHolder { + return ArgsChainParametersHolder{ + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, + ChainParameters: []config.ChainParametersByEpochConfig{ + { + EnableEpoch: 0, + ShardMinNumNodes: 5, + ShardConsensusGroupSize: 5, + MetachainMinNumNodes: 7, + MetachainConsensusGroupSize: 7, + RoundDuration: 4000, + Hysteresis: 0.2, + Adaptivity: false, + }, + }, + ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, + } + } + + t.Run("nil epoch start event notifier", func(t *testing.T) { + t.Parallel() + + args := getDummyArgs() + args.EpochStartEventNotifier = nil + + paramsHolder, err := NewChainParametersHolder(args) + require.True(t, check.IfNil(paramsHolder)) + require.Equal(t, ErrNilEpochStartEventNotifier, err) + }) + + t.Run("empty chain parameters", func(t *testing.T) { + t.Parallel() + + args := getDummyArgs() + args.ChainParameters = nil + + paramsHolder, err := NewChainParametersHolder(args) + require.True(t, check.IfNil(paramsHolder)) + require.Equal(t, ErrMissingChainParameters, err) + }) + + t.Run("invalid shard consensus size", func(t *testing.T) { + t.Parallel() + + args := getDummyArgs() + args.ChainParameters[0].ShardConsensusGroupSize = 0 + + paramsHolder, err := NewChainParametersHolder(args) + require.True(t, check.IfNil(paramsHolder)) + require.ErrorIs(t, err, ErrNegativeOrZeroConsensusGroupSize) + }) + + t.Run("min nodes per shard smaller than consensus size", func(t *testing.T) { + t.Parallel() + + args := getDummyArgs() + args.ChainParameters[0].ShardConsensusGroupSize = 5 + args.ChainParameters[0].ShardMinNumNodes = 4 + + paramsHolder, err := NewChainParametersHolder(args) + require.True(t, check.IfNil(paramsHolder)) + require.ErrorIs(t, err, ErrMinNodesPerShardSmallerThanConsensusSize) + }) + + t.Run("invalid metachain consensus size", func(t *testing.T) { + t.Parallel() + + args := getDummyArgs() + args.ChainParameters[0].MetachainConsensusGroupSize = 0 + + paramsHolder, err := NewChainParametersHolder(args) + require.True(t, check.IfNil(paramsHolder)) + require.ErrorIs(t, err, ErrNegativeOrZeroConsensusGroupSize) + }) + + t.Run("min nodes meta smaller than consensus size", func(t *testing.T) { + t.Parallel() + + args := getDummyArgs() + args.ChainParameters[0].MetachainConsensusGroupSize = 5 + args.ChainParameters[0].MetachainMinNumNodes = 4 + + paramsHolder, err := NewChainParametersHolder(args) + require.True(t, check.IfNil(paramsHolder)) + require.ErrorIs(t, err, ErrMinNodesPerShardSmallerThanConsensusSize) + }) + + t.Run("invalid future chain parameters", func(t *testing.T) { + t.Parallel() + + args := getDummyArgs() + newChainParameters := args.ChainParameters[0] + newChainParameters.ShardConsensusGroupSize = 0 + 
args.ChainParameters = append(args.ChainParameters, newChainParameters) + + paramsHolder, err := NewChainParametersHolder(args) + require.True(t, check.IfNil(paramsHolder)) + require.ErrorIs(t, err, ErrNegativeOrZeroConsensusGroupSize) + require.Contains(t, err.Error(), "index 1") + }) + + t.Run("no config for epoch 0", func(t *testing.T) { + t.Parallel() + + args := getDummyArgs() + args.ChainParameters[0].EnableEpoch = 37 + paramsHolder, err := NewChainParametersHolder(args) + require.True(t, check.IfNil(paramsHolder)) + require.ErrorIs(t, err, ErrMissingConfigurationForEpochZero) + }) + + t.Run("should work and have the data ready", func(t *testing.T) { + t.Parallel() + + args := getDummyArgs() + secondChainParams := args.ChainParameters[0] + secondChainParams.EnableEpoch = 5 + thirdChainParams := args.ChainParameters[0] + thirdChainParams.EnableEpoch = 10 + args.ChainParameters = append(args.ChainParameters, secondChainParams, thirdChainParams) + + paramsHolder, err := NewChainParametersHolder(args) + require.NoError(t, err) + require.False(t, check.IfNil(paramsHolder)) + + currentValue := paramsHolder.chainParameters[0] + for i := 1; i < len(paramsHolder.chainParameters); i++ { + require.Less(t, paramsHolder.chainParameters[i].EnableEpoch, currentValue.EnableEpoch) + currentValue = paramsHolder.chainParameters[i] + } + + require.Equal(t, uint32(0), paramsHolder.currentChainParameters.EnableEpoch) + }) +} + +func TestChainParametersHolder_EpochStartActionShouldCallTheNotifier(t *testing.T) { + t.Parallel() + + receivedValues := make([]uint32, 0) + notifier := &commonmocks.ChainParametersNotifierStub{ + UpdateCurrentChainParametersCalled: func(params config.ChainParametersByEpochConfig) { + receivedValues = append(receivedValues, params.ShardConsensusGroupSize) + }, + } + paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ + ChainParameters: []config.ChainParametersByEpochConfig{ + { + EnableEpoch: 0, + ShardConsensusGroupSize: 5, + ShardMinNumNodes: 7, + MetachainConsensusGroupSize: 7, + MetachainMinNumNodes: 7, + }, + { + EnableEpoch: 5, + ShardConsensusGroupSize: 37, + ShardMinNumNodes: 38, + MetachainConsensusGroupSize: 7, + MetachainMinNumNodes: 7, + }, + }, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, + ChainParametersNotifier: notifier, + }) + + paramsHolder.EpochStartAction(&block.MetaBlock{Epoch: 5}) + require.Equal(t, []uint32{5, 37}, receivedValues) +} + +func TestChainParametersHolder_ChainParametersForEpoch(t *testing.T) { + t.Parallel() + + t.Run("single configuration, should return it every time", func(t *testing.T) { + t.Parallel() + + params := []config.ChainParametersByEpochConfig{ + { + EnableEpoch: 0, + ShardConsensusGroupSize: 5, + ShardMinNumNodes: 7, + MetachainConsensusGroupSize: 7, + MetachainMinNumNodes: 7, + }, + } + + paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ + ChainParameters: params, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, + ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, + }) + + res, _ := paramsHolder.ChainParametersForEpoch(0) + require.Equal(t, uint32(5), res.ShardConsensusGroupSize) + require.Equal(t, uint32(7), res.MetachainConsensusGroupSize) + + res, _ = paramsHolder.ChainParametersForEpoch(1) + require.Equal(t, uint32(5), res.ShardConsensusGroupSize) + require.Equal(t, uint32(7), res.MetachainConsensusGroupSize) + + res, _ = paramsHolder.ChainParametersForEpoch(3700) + require.Equal(t, uint32(5), res.ShardConsensusGroupSize) + require.Equal(t, 
uint32(7), res.MetachainConsensusGroupSize) + }) + + t.Run("multiple configurations, should return the corresponding one", func(t *testing.T) { + t.Parallel() + + params := []config.ChainParametersByEpochConfig{ + { + EnableEpoch: 0, + ShardConsensusGroupSize: 5, + ShardMinNumNodes: 7, + MetachainConsensusGroupSize: 7, + MetachainMinNumNodes: 7, + }, + { + EnableEpoch: 10, + ShardConsensusGroupSize: 50, + ShardMinNumNodes: 70, + MetachainConsensusGroupSize: 70, + MetachainMinNumNodes: 70, + }, + { + EnableEpoch: 100, + ShardConsensusGroupSize: 500, + ShardMinNumNodes: 700, + MetachainConsensusGroupSize: 700, + MetachainMinNumNodes: 700, + }, + } + + paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ + ChainParameters: params, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, + ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, + }) + + for i := 0; i < 200; i++ { + res, _ := paramsHolder.ChainParametersForEpoch(uint32(i)) + if i < 10 { + require.Equal(t, uint32(5), res.ShardConsensusGroupSize) + require.Equal(t, uint32(7), res.MetachainConsensusGroupSize) + } else if i < 100 { + require.Equal(t, uint32(50), res.ShardConsensusGroupSize) + require.Equal(t, uint32(70), res.MetachainConsensusGroupSize) + } else { + require.Equal(t, uint32(500), res.ShardConsensusGroupSize) + require.Equal(t, uint32(700), res.MetachainConsensusGroupSize) + } + } + }) +} + +func TestChainParametersHolder_CurrentChainParameters(t *testing.T) { + t.Parallel() + + params := []config.ChainParametersByEpochConfig{ + { + EnableEpoch: 0, + ShardConsensusGroupSize: 5, + ShardMinNumNodes: 7, + MetachainConsensusGroupSize: 7, + MetachainMinNumNodes: 7, + }, + { + EnableEpoch: 10, + ShardConsensusGroupSize: 50, + ShardMinNumNodes: 70, + MetachainConsensusGroupSize: 70, + MetachainMinNumNodes: 70, + }, + } + + paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ + ChainParameters: params, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, + ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, + }) + + paramsHolder.EpochStartAction(&block.MetaBlock{Epoch: 0}) + require.Equal(t, uint32(5), paramsHolder.CurrentChainParameters().ShardConsensusGroupSize) + + paramsHolder.EpochStartAction(&block.MetaBlock{Epoch: 3}) + require.Equal(t, uint32(5), paramsHolder.CurrentChainParameters().ShardConsensusGroupSize) + + paramsHolder.EpochStartAction(&block.MetaBlock{Epoch: 10}) + require.Equal(t, uint32(50), paramsHolder.CurrentChainParameters().ShardConsensusGroupSize) + + paramsHolder.EpochStartAction(&block.MetaBlock{Epoch: 999}) + require.Equal(t, uint32(50), paramsHolder.CurrentChainParameters().ShardConsensusGroupSize) +} + +func TestChainParametersHolder_AllChainParameters(t *testing.T) { + t.Parallel() + + params := []config.ChainParametersByEpochConfig{ + { + EnableEpoch: 0, + ShardConsensusGroupSize: 5, + ShardMinNumNodes: 7, + MetachainConsensusGroupSize: 7, + MetachainMinNumNodes: 7, + }, + { + EnableEpoch: 10, + ShardConsensusGroupSize: 50, + ShardMinNumNodes: 70, + MetachainConsensusGroupSize: 70, + MetachainMinNumNodes: 70, + }, + } + + paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ + ChainParameters: params, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, + ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, + }) + + returnedAllChainsParameters := paramsHolder.AllChainParameters() + require.Equal(t, params, returnedAllChainsParameters) + require.NotEqual(t, fmt.Sprintf("%p", 
returnedAllChainsParameters), fmt.Sprintf("%p", paramsHolder.chainParameters)) +} + +func TestChainParametersHolder_ConcurrentOperations(t *testing.T) { + chainParams := make([]config.ChainParametersByEpochConfig, 0) + for i := uint32(0); i <= 100; i += 5 { + chainParams = append(chainParams, config.ChainParametersByEpochConfig{ + RoundDuration: 4000, + Hysteresis: 0.2, + EnableEpoch: i, + ShardConsensusGroupSize: i*10 + 1, + ShardMinNumNodes: i*10 + 1, + MetachainConsensusGroupSize: i*10 + 1, + MetachainMinNumNodes: i*10 + 1, + Adaptivity: false, + }) + } + + paramsHolder, _ := NewChainParametersHolder(ArgsChainParametersHolder{ + ChainParameters: chainParams, + EpochStartEventNotifier: &mock.EpochStartNotifierStub{}, + ChainParametersNotifier: &commonmocks.ChainParametersNotifierStub{}, + }) + + numOperations := 500 + wg := sync.WaitGroup{} + wg.Add(numOperations) + for i := 0; i < numOperations; i++ { + go func(idx int) { + switch idx { + case 0: + paramsHolder.EpochStartAction(&block.MetaBlock{Epoch: uint32(idx)}) + case 1: + _ = paramsHolder.CurrentChainParameters() + case 2: + _, _ = paramsHolder.ChainParametersForEpoch(uint32(idx)) + case 3: + _ = paramsHolder.AllChainParameters() + } + + wg.Done() + }(i % 4) + } + + wg.Wait() +} diff --git a/sharding/dtos.go b/sharding/dtos.go new file mode 100644 index 00000000000..33cb7c8e660 --- /dev/null +++ b/sharding/dtos.go @@ -0,0 +1,18 @@ +package sharding + +// ConsensusConfiguration holds the consensus configuration that can be used by both the shard and the metachain +type ConsensusConfiguration struct { + EnableEpoch uint32 + MinNodes uint32 + ConsensusGroupSize uint32 +} + +// NodesSetupDTO is the data transfer object used to map the nodes' configuration in regard to the genesis nodes setup +type NodesSetupDTO struct { + StartTime int64 `json:"startTime"` + RoundDuration uint64 `json:"roundDuration"` + Hysteresis float32 `json:"hysteresis"` + Adaptivity bool `json:"adaptivity"` + + InitialNodes []*InitialNode `json:"initialNodes"` +} diff --git a/sharding/errors.go b/sharding/errors.go index 8190d8ba4ec..e6c2c29984c 100644 --- a/sharding/errors.go +++ b/sharding/errors.go @@ -39,3 +39,24 @@ var ErrNilOwnPublicKey = errors.New("nil own public key") // ErrNilEndOfProcessingHandler signals that a nil end of processing handler has been provided var ErrNilEndOfProcessingHandler = errors.New("nil end of processing handler") + +// ErrNilChainParametersProvider signals that a nil chain parameters provider has been given +var ErrNilChainParametersProvider = errors.New("nil chain parameters provider") + +// ErrNilEpochStartEventNotifier signals that a nil epoch start event notifier has been provided +var ErrNilEpochStartEventNotifier = errors.New("nil epoch start event notifier") + +// ErrMissingChainParameters signals that a nil chain parameters array has been provided +var ErrMissingChainParameters = errors.New("empty chain parameters array") + +// ErrMissingConfigurationForEpochZero signals that no configuration for epoch 0 exists +var ErrMissingConfigurationForEpochZero = errors.New("missing configuration for epoch 0") + +// ErrNoMatchingConfigurationFound signals that no matching configuration is found +var ErrNoMatchingConfigurationFound = errors.New("no matching configuration found") + +// ErrNilChainParametersNotifier signals that a nil chain parameters notifier has been provided +var ErrNilChainParametersNotifier = errors.New("nil chain parameters notifier") + +// ErrInvalidChainParametersForEpoch signals that an invalid chain 
parameters for epoch has been provided +var ErrInvalidChainParametersForEpoch = errors.New("invalid chain parameters for epoch") diff --git a/sharding/interface.go b/sharding/interface.go index 40180ec3bb5..06191510539 100644 --- a/sharding/interface.go +++ b/sharding/interface.go @@ -1,6 +1,10 @@ package sharding -import "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" +import ( + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/epochStart" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" +) // Coordinator defines what a shard state coordinator should hold type Coordinator interface { @@ -61,5 +65,26 @@ type GenesisNodesSetupHandler interface { MinNumberOfNodesWithHysteresis() uint32 MinShardHysteresisNodes() uint32 MinMetaHysteresisNodes() uint32 + ExportNodesConfig() config.NodesConfig + IsInterfaceNil() bool +} + +// EpochStartEventNotifier provides Register and Unregister functionality for the end of epoch events +type EpochStartEventNotifier interface { + RegisterHandler(handler epochStart.ActionHandler) + UnregisterHandler(handler epochStart.ActionHandler) + IsInterfaceNil() bool +} + +// ChainParametersHandler defines the actions that need to be done by a component that can handle chain parameters +type ChainParametersHandler interface { + CurrentChainParameters() config.ChainParametersByEpochConfig + ChainParametersForEpoch(epoch uint32) (config.ChainParametersByEpochConfig, error) + IsInterfaceNil() bool +} + +// ChainParametersNotifierHandler defines the actions that need to be done by a component that can handle chain parameters changes +type ChainParametersNotifierHandler interface { + UpdateCurrentChainParameters(params config.ChainParametersByEpochConfig) IsInterfaceNil() bool } diff --git a/sharding/mock/epochHandlerMock.go b/sharding/mock/epochHandlerMock.go deleted file mode 100644 index 9b78066bd3e..00000000000 --- a/sharding/mock/epochHandlerMock.go +++ /dev/null @@ -1,16 +0,0 @@ -package mock - -// EpochHandlerMock - -type EpochHandlerMock struct { - EpochValue uint32 -} - -// Epoch - -func (ehm *EpochHandlerMock) Epoch() uint32 { - return ehm.EpochValue -} - -// IsInterfaceNil - -func (ehm *EpochHandlerMock) IsInterfaceNil() bool { - return ehm == nil -} diff --git a/sharding/mock/epochHandlerStub.go b/sharding/mock/epochHandlerStub.go deleted file mode 100644 index 4470eaca56c..00000000000 --- a/sharding/mock/epochHandlerStub.go +++ /dev/null @@ -1,20 +0,0 @@ -package mock - -// EpochHandlerStub - -type EpochHandlerStub struct { - EpochCalled func() uint32 -} - -// Epoch - -func (ehs *EpochHandlerStub) Epoch() uint32 { - if ehs.EpochCalled != nil { - return ehs.EpochCalled() - } - - return uint32(0) -} - -// IsInterfaceNil - -func (ehs *EpochHandlerStub) IsInterfaceNil() bool { - return ehs == nil -} diff --git a/sharding/mock/epochStartNotifierStub.go b/sharding/mock/epochStartNotifierStub.go deleted file mode 100644 index 53406c12920..00000000000 --- a/sharding/mock/epochStartNotifierStub.go +++ /dev/null @@ -1,47 +0,0 @@ -package mock - -import ( - "github.com/multiversx/mx-chain-core-go/data" - "github.com/multiversx/mx-chain-go/epochStart" -) - -// EpochStartNotifierStub - -type EpochStartNotifierStub struct { - RegisterHandlerCalled func(handler epochStart.ActionHandler) - UnregisterHandlerCalled func(handler epochStart.ActionHandler) - NotifyAllPrepareCalled func(hdr data.HeaderHandler, body 
data.BodyHandler, validatorInfoCacher epochStart.ValidatorInfoCacher) - NotifyAllCalled func(hdr data.HeaderHandler) -} - -// RegisterHandler - -func (esnm *EpochStartNotifierStub) RegisterHandler(handler epochStart.ActionHandler) { - if esnm.RegisterHandlerCalled != nil { - esnm.RegisterHandlerCalled(handler) - } -} - -// UnregisterHandler - -func (esnm *EpochStartNotifierStub) UnregisterHandler(handler epochStart.ActionHandler) { - if esnm.UnregisterHandlerCalled != nil { - esnm.UnregisterHandlerCalled(handler) - } -} - -// NotifyAllPrepare - -func (esnm *EpochStartNotifierStub) NotifyAllPrepare(metaHdr data.HeaderHandler, body data.BodyHandler, validatorInfoCacher epochStart.ValidatorInfoCacher) { - if esnm.NotifyAllPrepareCalled != nil { - esnm.NotifyAllPrepareCalled(metaHdr, body, validatorInfoCacher) - } -} - -// NotifyAll - -func (esnm *EpochStartNotifierStub) NotifyAll(hdr data.HeaderHandler) { - if esnm.NotifyAllCalled != nil { - esnm.NotifyAllCalled(hdr) - } -} - -// IsInterfaceNil - -func (esnm *EpochStartNotifierStub) IsInterfaceNil() bool { - return esnm == nil -} diff --git a/sharding/mock/hasherStub.go b/sharding/mock/hasherStub.go deleted file mode 100644 index f05c2fd2cc8..00000000000 --- a/sharding/mock/hasherStub.go +++ /dev/null @@ -1,28 +0,0 @@ -package mock - -// HasherStub - -type HasherStub struct { - ComputeCalled func(s string) []byte - EmptyHashCalled func() []byte - SizeCalled func() int -} - -// Compute will output the SHA's equivalent of the input string -func (hs *HasherStub) Compute(s string) []byte { - return hs.ComputeCalled(s) -} - -// EmptyHash will return the equivalent of empty string SHA's -func (hs *HasherStub) EmptyHash() []byte { - return hs.EmptyHashCalled() -} - -// Size returns the required size in bytes -func (hs *HasherStub) Size() int { - return hs.SizeCalled() -} - -// IsInterfaceNil returns true if there is no value under the interface -func (hs *HasherStub) IsInterfaceNil() bool { - return hs == nil -} diff --git a/sharding/mock/listIndexUpdaterStub.go b/sharding/mock/listIndexUpdaterStub.go deleted file mode 100644 index 31c5ae19b76..00000000000 --- a/sharding/mock/listIndexUpdaterStub.go +++ /dev/null @@ -1,20 +0,0 @@ -package mock - -// ListIndexUpdaterStub - -type ListIndexUpdaterStub struct { - UpdateListAndIndexCalled func(pubKey string, shardID uint32, list string, index uint32) error -} - -// UpdateListAndIndex - -func (lius *ListIndexUpdaterStub) UpdateListAndIndex(pubKey string, shardID uint32, list string, index uint32) error { - if lius.UpdateListAndIndexCalled != nil { - return lius.UpdateListAndIndexCalled(pubKey, shardID, list, index) - } - - return nil -} - -// IsInterfaceNil returns true if there is no value under the interface -func (lius *ListIndexUpdaterStub) IsInterfaceNil() bool { - return lius == nil -} diff --git a/sharding/mock/multipleShardsCoordinatorFake.go b/sharding/mock/multipleShardsCoordinatorFake.go deleted file mode 100644 index 89a145beca2..00000000000 --- a/sharding/mock/multipleShardsCoordinatorFake.go +++ /dev/null @@ -1,94 +0,0 @@ -package mock - -import ( - "fmt" - "math" -) - -type multipleShardsCoordinatorFake struct { - numOfShards uint32 - CurrentShard uint32 - maskHigh uint32 - maskLow uint32 -} - -// NewMultipleShardsCoordinatorFake - -func NewMultipleShardsCoordinatorFake(numOfShards uint32, currentShard uint32) *multipleShardsCoordinatorFake { - mscf := &multipleShardsCoordinatorFake{ - numOfShards: numOfShards, - CurrentShard: currentShard, - } - mscf.maskHigh, mscf.maskLow = 
mscf.calculateMasks() - return mscf -} - -func (mscf *multipleShardsCoordinatorFake) calculateMasks() (uint32, uint32) { - n := math.Ceil(math.Log2(float64(mscf.numOfShards))) - return (1 << uint(n)) - 1, (1 << uint(n-1)) - 1 -} - -// NumberOfShards - -func (mscf *multipleShardsCoordinatorFake) NumberOfShards() uint32 { - return mscf.numOfShards -} - -// ComputeId - -func (mscf *multipleShardsCoordinatorFake) ComputeId(address []byte) uint32 { - bytesNeed := int(mscf.numOfShards/256) + 1 - startingIndex := 0 - if len(address) > bytesNeed { - startingIndex = len(address) - bytesNeed - } - - buffNeeded := address[startingIndex:] - - addr := uint32(0) - for i := 0; i < len(buffNeeded); i++ { - addr = addr<<8 + uint32(buffNeeded[i]) - } - - shard := addr & mscf.maskHigh - if shard > mscf.numOfShards-1 { - shard = addr & mscf.maskLow - } - return shard -} - -// SelfId - -func (mscf *multipleShardsCoordinatorFake) SelfId() uint32 { - return mscf.CurrentShard -} - -// SetSelfId - -func (mscf *multipleShardsCoordinatorFake) SetSelfId(_ uint32) error { - return nil -} - -// SameShard - -func (mscf *multipleShardsCoordinatorFake) SameShard(_, _ []byte) bool { - return true -} - -// SetNoShards - -func (mscf *multipleShardsCoordinatorFake) SetNoShards(numOfShards uint32) { - mscf.numOfShards = numOfShards -} - -// CommunicationIdentifier returns the identifier between current shard ID and destination shard ID -// identifier is generated such as the first shard from identifier is always smaller than the last -func (mscf *multipleShardsCoordinatorFake) CommunicationIdentifier(destShardID uint32) string { - if destShardID == mscf.CurrentShard { - return fmt.Sprintf("_%d", mscf.CurrentShard) - } - - if destShardID < mscf.CurrentShard { - return fmt.Sprintf("_%d_%d", destShardID, mscf.CurrentShard) - } - - return fmt.Sprintf("_%d_%d", mscf.CurrentShard, destShardID) -} - -// IsInterfaceNil returns true if there is no value under the interface -func (mscf *multipleShardsCoordinatorFake) IsInterfaceNil() bool { - return mscf == nil -} diff --git a/sharding/mock/pubkeyConverterMock.go b/sharding/mock/pubkeyConverterMock.go index e81d21ff4f6..2679da82d02 100644 --- a/sharding/mock/pubkeyConverterMock.go +++ b/sharding/mock/pubkeyConverterMock.go @@ -8,7 +8,8 @@ import ( // PubkeyConverterMock - type PubkeyConverterMock struct { - len int + len int + DecodeCalled func() ([]byte, error) } // NewPubkeyConverterMock - @@ -20,6 +21,9 @@ func NewPubkeyConverterMock(addressLen int) *PubkeyConverterMock { // Decode - func (pcm *PubkeyConverterMock) Decode(humanReadable string) ([]byte, error) { + if pcm.DecodeCalled != nil { + return pcm.DecodeCalled() + } return hex.DecodeString(humanReadable) } diff --git a/sharding/mock/testdata/invalidNodesSetupMock.json b/sharding/mock/testdata/invalidNodesSetupMock.json deleted file mode 100644 index 67458949a71..00000000000 --- a/sharding/mock/testdata/invalidNodesSetupMock.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "startTime": 0, - "roundDuration": 4000, - "consensusGroupSize": 0, - "minNodesPerShard": 1, - "metaChainActive" : true, - "metaChainConsensusGroupSize" : 1, - "metaChainMinNodes" : 1, - "initialNodes": [ - { - "pubkey": "41378f754e2c7b2745208c3ed21b151d297acdc84c3aca00b9e292cf28ec2d444771070157ea7760ed83c26f4fed387d0077e00b563a95825dac2cbc349fc0025ccf774e37b0a98ad9724d30e90f8c29b4091ccb738ed9ffc0573df776ee9ea30b3c038b55e532760ea4a8f152f2a52848020e5cee1cc537f2c2323399723081", - "address": "9e95a4e46da335a96845b4316251fc1bb197e1b8136d96ecc62bf6604eca9e49" - }, - { - 
"pubkey": "52f3bf5c01771f601ec2137e267319ab6716ef6ff5dfddaea48b42d955f631167f2ce19296a202bb8fd174f4e94f8c85f619df85a7f9f8de0f3768e5e6d8c48187b767deccf9829be246aa331aa86d182eb8fa28ea8a3e45d357ed1647a9be020a5569d686253a6f89e9123c7f21f302e82f67d3e3cd69cf267b9910a663ef32", - "address": "7a330039e77ca06bc127319fd707cc4911a80db489a39fcfb746283a05f61836" - } - ] -} diff --git a/sharding/mock/testdata/nodesSetupMock.json b/sharding/mock/testdata/nodesSetupMock.json deleted file mode 100644 index 17cf384c5b4..00000000000 --- a/sharding/mock/testdata/nodesSetupMock.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "startTime": 0, - "roundDuration": 4000, - "consensusGroupSize": 1, - "minNodesPerShard": 1, - "metaChainActive" : true, - "metaChainConsensusGroupSize" : 1, - "metaChainMinNodes" : 1, - "initialNodes": [ - { - "pubkey": "41378f754e2c7b2745208c3ed21b151d297acdc84c3aca00b9e292cf28ec2d444771070157ea7760ed83c26f4fed387d0077e00b563a95825dac2cbc349fc0025ccf774e37b0a98ad9724d30e90f8c29b4091ccb738ed9ffc0573df776ee9ea30b3c038b55e532760ea4a8f152f2a52848020e5cee1cc537f2c2323399723081", - "address": "9e95a4e46da335a96845b4316251fc1bb197e1b8136d96ecc62bf6604eca9e49" - }, - { - "pubkey": "52f3bf5c01771f601ec2137e267319ab6716ef6ff5dfddaea48b42d955f631167f2ce19296a202bb8fd174f4e94f8c85f619df85a7f9f8de0f3768e5e6d8c48187b767deccf9829be246aa331aa86d182eb8fa28ea8a3e45d357ed1647a9be020a5569d686253a6f89e9123c7f21f302e82f67d3e3cd69cf267b9910a663ef32", - "address": "7a330039e77ca06bc127319fd707cc4911a80db489a39fcfb746283a05f61836" - }, - { - "pubkey": "5e91c426c5c8f5f805f86de1e0653e2ec33853772e583b88e9f0f201089d03d8570759c3c3ab610ce573493c33ba0adf954c8939dba5d5ef7f2be4e87145d8153fc5b4fb91cecb8d9b1f62e080743fbf69c8c3096bf07980bb82cb450ba9b902673373d5b671ea73620cc5bc4d36f7a0f5ca3684d4c8aa5c1b425ab2a8673140", - "address": "131e2e717f2d33bdf7850c12b03dfe41ea8a5e76fdd6d4f23aebe558603e746f" - }, - { - "pubkey": "73972bf46dca59fba211c58f11b530f8e9d6392c499655ce760abc6458fd9c6b54b9676ee4b95aa32f6c254c9aad2f63a6195cd65d837a4320d7b8e915ba3a7123c8f4983b201035573c0752bb54e9021eb383b40d302447b62ea7a3790c89c47f5ab81d183f414e87611a31ff635ad22e969495356d5bc44eec7917aaad4c5e", - "address": "4c9e66b605882c1099088f26659692f084e41dc0dedfaedf6a6409af21c02aac" - }, - { - "pubkey": "7391ccce066ab5674304b10220643bc64829afa626a165f1e7a6618e260fa68f8e79018ac5964f7a1b8dd419645049042e34ebe7f2772def71e6176ce9daf50a57c17ee2a7445b908fe47e8f978380fcc2654a19925bf73db2402b09dde515148081f8ca7c331fbedec689de1b7bfce6bf106e4433557c29752c12d0a009f47a", - "address": "90a66900634b206d20627fbaec432ebfbabeaf30b9e338af63191435e2e37022" - } - ] -} diff --git a/sharding/networksharding/peerShardMapper_test.go b/sharding/networksharding/peerShardMapper_test.go index fef620ed90d..6b03abe6805 100644 --- a/sharding/networksharding/peerShardMapper_test.go +++ b/sharding/networksharding/peerShardMapper_test.go @@ -9,23 +9,24 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/sharding/networksharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" 
"github.com/multiversx/mx-chain-go/testscommon/p2pmocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" - "github.com/stretchr/testify/assert" ) // ------- NewPeerShardMapper func createMockArgumentForPeerShardMapper() networksharding.ArgPeerShardMapper { return networksharding.ArgPeerShardMapper{ - PeerIdPkCache: testscommon.NewCacherMock(), - FallbackPkShardCache: testscommon.NewCacherMock(), - FallbackPidShardCache: testscommon.NewCacherMock(), + PeerIdPkCache: cache.NewCacherMock(), + FallbackPkShardCache: cache.NewCacherMock(), + FallbackPidShardCache: cache.NewCacherMock(), NodesCoordinator: &shardingMocks.NodesCoordinatorMock{}, PreferredPeersHolder: &p2pmocks.PeersHolderStub{}, } diff --git a/sharding/nodesCoordinator/common_test.go b/sharding/nodesCoordinator/common_test.go index 50be55fd1ae..b7902db0c7e 100644 --- a/sharding/nodesCoordinator/common_test.go +++ b/sharding/nodesCoordinator/common_test.go @@ -5,10 +5,13 @@ import ( "encoding/binary" "fmt" "math/big" + "strconv" "testing" - "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" ) func TestComputeStartIndexAndNumAppearancesForValidator(t *testing.T) { @@ -151,3 +154,105 @@ func getExpandedEligibleList(num int) []uint32 { func newValidatorMock(pubKey []byte, chances uint32, index uint32) *validator { return &validator{pubKey: pubKey, index: index, chances: chances} } + +func TestSerializableShardValidatorListToValidatorListShouldErrNilPubKey(t *testing.T) { + t.Parallel() + + listOfSerializableValidators := []*SerializableValidator{ + { + PubKey: nil, + Chances: 1, + Index: 1, + }, + } + + _, err := SerializableShardValidatorListToValidatorList(listOfSerializableValidators) + require.Equal(t, ErrNilPubKey, err) +} + +func TestSerializableShardValidatorListToValidatorListShouldWork(t *testing.T) { + t.Parallel() + + listOfSerializableValidators := []*SerializableValidator{ + { + PubKey: []byte("pubkey"), + Chances: 1, + Index: 1, + }, + } + + expectedListOfValidators := make([]Validator, 1) + v, _ := NewValidator(listOfSerializableValidators[0].PubKey, listOfSerializableValidators[0].Chances, listOfSerializableValidators[0].Index) + require.NotNil(t, v) + expectedListOfValidators[0] = v + + valReturned, err := SerializableShardValidatorListToValidatorList(listOfSerializableValidators) + + require.Nil(t, err) + require.Equal(t, expectedListOfValidators, valReturned) +} + +func TestSerializableValidatorsToValidatorsShouldWork(t *testing.T) { + t.Parallel() + + mapOfSerializableValidators := make(map[string][]*SerializableValidator, 1) + mapOfSerializableValidators["1"] = []*SerializableValidator{ + { + PubKey: []byte("pubkey"), + Chances: 1, + Index: 1, + }, + } + + expectedMapOfValidators := make(map[uint32][]Validator, 1) + + v, _ := NewValidator(mapOfSerializableValidators["1"][0].PubKey, mapOfSerializableValidators["1"][0].Chances, mapOfSerializableValidators["1"][0].Index) + expectedMapOfValidators[uint32(1)] = []Validator{v} + + require.NotNil(t, v) + + valReturned, err := SerializableValidatorsToValidators(mapOfSerializableValidators) + + require.Nil(t, err) + require.Equal(t, expectedMapOfValidators, valReturned) +} + +func TestSerializableValidatorsToValidatorsShouldErrNilPubKey(t *testing.T) { + t.Parallel() + + 
mapOfSerializableValidators := make(map[string][]*SerializableValidator, 1) + mapOfSerializableValidators["1"] = []*SerializableValidator{ + { + PubKey: nil, + Chances: 1, + Index: 1, + }, + } + + _, err := SerializableValidatorsToValidators(mapOfSerializableValidators) + + require.Equal(t, ErrNilPubKey, err) +} + +func TestSerializableValidatorsToValidatorsShouldErrEmptyString(t *testing.T) { + t.Parallel() + + mapOfSerializableValidators := make(map[string][]*SerializableValidator, 1) + mapOfSerializableValidators[""] = []*SerializableValidator{ + { + PubKey: []byte("pubkey"), + Chances: 1, + Index: 1, + }, + } + + expectedMapOfValidators := make(map[uint32][]Validator, 1) + + v, _ := NewValidator(mapOfSerializableValidators[""][0].PubKey, mapOfSerializableValidators[""][0].Chances, mapOfSerializableValidators[""][0].Index) + require.NotNil(t, v) + expectedMapOfValidators[uint32(1)] = []Validator{v} + + _, err := SerializableValidatorsToValidators(mapOfSerializableValidators) + + require.Equal(t, &strconv.NumError{Func: "ParseUint", Num: "", Err: strconv.ErrSyntax}, err) +} diff --git a/sharding/nodesCoordinator/dtos.go b/sharding/nodesCoordinator/dtos.go index 75c28194a6a..5bd82fd0432 100644 --- a/sharding/nodesCoordinator/dtos.go +++ b/sharding/nodesCoordinator/dtos.go @@ -1,7 +1,10 @@ package nodesCoordinator +import "github.com/multiversx/mx-chain-go/config" + // ArgsUpdateNodes holds the parameters required by the shuffler to generate a new nodes configuration type ArgsUpdateNodes struct { + ChainParameters config.ChainParametersByEpochConfig Eligible map[uint32][]Validator Waiting map[uint32][]Validator NewNodes []Validator diff --git a/sharding/nodesCoordinator/errors.go b/sharding/nodesCoordinator/errors.go index 3d063f4605e..901559116ab 100644 --- a/sharding/nodesCoordinator/errors.go +++ b/sharding/nodesCoordinator/errors.go @@ -40,9 +40,6 @@ var ErrNilPreviousEpochConfig = errors.New("nil previous epoch config") // ErrEpochNodesConfigDoesNotExist signals that the epoch nodes configuration is missing var ErrEpochNodesConfigDoesNotExist = errors.New("epoch nodes configuration does not exist") -// ErrInvalidConsensusGroupSize signals that the consensus size is invalid (e.g. 
value is negative) -var ErrInvalidConsensusGroupSize = errors.New("invalid consensus group size") - // ErrNilRandomness signals that a nil randomness source has been provided var ErrNilRandomness = errors.New("nil randomness source") @@ -123,3 +120,9 @@ var ErrReceivedAuctionValidatorsBeforeStakingV4 = errors.New("should not have re // ErrNilEpochNotifier signals that a nil EpochNotifier has been provided var ErrNilEpochNotifier = errors.New("nil epoch notifier provided") + +// ErrNilChainParametersHandler signals that a nil chain parameters handler has been provided +var ErrNilChainParametersHandler = errors.New("nil chain parameters handler") + +// ErrEmptyValidatorsList signals that the validators list is empty +var ErrEmptyValidatorsList = errors.New("empty validators list") diff --git a/sharding/nodesCoordinator/export_test.go b/sharding/nodesCoordinator/export_test.go new file mode 100644 index 00000000000..365a4bc4322 --- /dev/null +++ b/sharding/nodesCoordinator/export_test.go @@ -0,0 +1,9 @@ +package nodesCoordinator + +// AddDummyEpoch adds the epoch in the cached map +func (ihnc *indexHashedNodesCoordinator) AddDummyEpoch(epoch uint32) { + ihnc.mutNodesConfig.Lock() + defer ihnc.mutNodesConfig.Unlock() + + ihnc.nodesConfig[epoch] = &epochNodesConfig{} +} diff --git a/sharding/nodesCoordinator/hashValidatorShuffler.go b/sharding/nodesCoordinator/hashValidatorShuffler.go index 71d2b5351b3..27e2459e2dd 100644 --- a/sharding/nodesCoordinator/hashValidatorShuffler.go +++ b/sharding/nodesCoordinator/hashValidatorShuffler.go @@ -20,10 +20,6 @@ var _ NodesShuffler = (*randHashShuffler)(nil) // NodesShufflerArgs defines the arguments required to create a nodes shuffler type NodesShufflerArgs struct { - NodesShard uint32 - NodesMeta uint32 - Hysteresis float32 - Adaptivity bool ShuffleBetweenShards bool MaxNodesEnableConfig []config.MaxNodesChangeConfig EnableEpochsHandler common.EnableEpochsHandler @@ -65,11 +61,6 @@ type randHashShuffler struct { // when reinitialization of node in new shard is implemented shuffleBetweenShards bool - adaptivity bool - nodesShard uint32 - nodesMeta uint32 - shardHysteresis uint32 - metaHysteresis uint32 activeNodesConfig config.MaxNodesChangeConfig availableNodesConfigs []config.MaxNodesChangeConfig mutShufflerParams sync.RWMutex @@ -118,8 +109,6 @@ func NewHashValidatorsShuffler(args *NodesShufflerArgs) (*randHashShuffler, erro stakingV4Step3EnableEpoch: args.EnableEpochs.StakingV4Step3EnableEpoch, } - rxs.UpdateParams(args.NodesShard, args.NodesMeta, args.Hysteresis, args.Adaptivity) - if rxs.shuffleBetweenShards { rxs.validatorDistributor = &CrossShardValidatorDistributor{} } else { @@ -131,27 +120,6 @@ func NewHashValidatorsShuffler(args *NodesShufflerArgs) (*randHashShuffler, erro return rxs, nil } -// UpdateParams updates the shuffler parameters -// Should be called when new params are agreed through governance -func (rhs *randHashShuffler) UpdateParams( - nodesShard uint32, - nodesMeta uint32, - hysteresis float32, - adaptivity bool, -) { - // TODO: are there constraints we want to enforce? 
e.g min/max hysteresis - shardHysteresis := uint32(float32(nodesShard) * hysteresis) - metaHysteresis := uint32(float32(nodesMeta) * hysteresis) - - rhs.mutShufflerParams.Lock() - rhs.shardHysteresis = shardHysteresis - rhs.metaHysteresis = metaHysteresis - rhs.nodesShard = nodesShard - rhs.nodesMeta = nodesMeta - rhs.adaptivity = adaptivity - rhs.mutShufflerParams.Unlock() -} - // UpdateNodeLists shuffles the nodes and returns the lists with the new nodes configuration // The function needs to ensure that: // 1. Old eligible nodes list will have up to shuffleOutThreshold percent nodes shuffled out from each shard @@ -169,14 +137,17 @@ func (rhs *randHashShuffler) UpdateParams( // execute the shard merge // c) No change in the number of shards then nothing extra needs to be done func (rhs *randHashShuffler) UpdateNodeLists(args ArgsUpdateNodes) (*ResUpdateNodes, error) { - rhs.updateShufflerConfig(args.Epoch) + chainParameters := args.ChainParameters + + rhs.UpdateShufflerConfig(args.Epoch, chainParameters) eligibleAfterReshard := copyValidatorMap(args.Eligible) waitingAfterReshard := copyValidatorMap(args.Waiting) - args.AdditionalLeaving = removeDupplicates(args.UnStakeLeaving, args.AdditionalLeaving) + args.AdditionalLeaving = removeDuplicates(args.UnStakeLeaving, args.AdditionalLeaving) totalLeavingNum := len(args.AdditionalLeaving) + len(args.UnStakeLeaving) newNbShards := rhs.computeNewShards( + chainParameters, args.Eligible, args.Waiting, len(args.NewNodes), @@ -185,10 +156,10 @@ func (rhs *randHashShuffler) UpdateNodeLists(args ArgsUpdateNodes) (*ResUpdateNo ) rhs.mutShufflerParams.RLock() - canSplit := rhs.adaptivity && newNbShards > args.NbShards - canMerge := rhs.adaptivity && newNbShards < args.NbShards - nodesPerShard := rhs.nodesShard - nodesMeta := rhs.nodesMeta + canSplit := chainParameters.Adaptivity && newNbShards > args.NbShards + canMerge := chainParameters.Adaptivity && newNbShards < args.NbShards + nodesPerShard := chainParameters.ShardMinNumNodes + nodesMeta := chainParameters.MetachainMinNumNodes rhs.mutShufflerParams.RUnlock() if canSplit { @@ -219,7 +190,7 @@ func (rhs *randHashShuffler) UpdateNodeLists(args ArgsUpdateNodes) (*ResUpdateNo }) } -func removeDupplicates(unstake []Validator, additionalLeaving []Validator) []Validator { +func removeDuplicates(unstake []Validator, additionalLeaving []Validator) []Validator { additionalCopy := make([]Validator, 0, len(additionalLeaving)) additionalCopy = append(additionalCopy, additionalLeaving...) 
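With UpdateParams removed, UpdateNodeLists above now takes its per-epoch sizing (minimum nodes per shard and metachain, hysteresis, adaptivity) from args.ChainParameters instead of fields stored on the shuffler. A minimal sketch of how a caller might populate the new field; the helper name buildShuffleArgs and the numeric values are illustrative, not part of this change:

package example

import (
	"github.com/multiversx/mx-chain-go/config"
	"github.com/multiversx/mx-chain-go/sharding/nodesCoordinator"
)

// buildShuffleArgs shows the ArgsUpdateNodes shape after this change: the chain
// parameters travel with each shuffle request rather than living on the shuffler.
func buildShuffleArgs(
	eligible map[uint32][]nodesCoordinator.Validator,
	waiting map[uint32][]nodesCoordinator.Validator,
	randomness []byte,
	nbShards uint32,
	epoch uint32,
) nodesCoordinator.ArgsUpdateNodes {
	return nodesCoordinator.ArgsUpdateNodes{
		ChainParameters: config.ChainParametersByEpochConfig{
			EnableEpoch:          0,
			Hysteresis:           0.2, // illustrative value
			ShardMinNumNodes:     400, // illustrative value
			MetachainMinNumNodes: 400, // illustrative value
			Adaptivity:           false,
		},
		Eligible: eligible,
		Waiting:  waiting,
		Rand:     randomness,
		NbShards: nbShards,
		Epoch:    epoch,
	}
}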
@@ -451,6 +422,7 @@ func removeLeavingNodesFromValidatorMaps( // computeNewShards determines the new number of shards based on the number of nodes in the network func (rhs *randHashShuffler) computeNewShards( + chainParameters config.ChainParametersByEpochConfig, eligible map[uint32][]Validator, waiting map[uint32][]Validator, numNewNodes int, @@ -468,10 +440,10 @@ func (rhs *randHashShuffler) computeNewShards( nodesNewEpoch := uint32(nbEligible + nbWaiting + numNewNodes - numLeavingNodes) rhs.mutShufflerParams.RLock() - maxNodesMeta := rhs.nodesMeta + rhs.metaHysteresis - maxNodesShard := rhs.nodesShard + rhs.shardHysteresis + maxNodesMeta := chainParameters.MetachainMinNumNodes + rhs.metaHysteresis(chainParameters) + maxNodesShard := chainParameters.ShardMinNumNodes + rhs.shardHysteresis(chainParameters) nodesForSplit := (nbShards+1)*maxNodesShard + maxNodesMeta - nodesForMerge := nbShards*rhs.nodesShard + rhs.nodesMeta + nodesForMerge := nbShards*chainParameters.ShardMinNumNodes + chainParameters.MetachainMinNumNodes rhs.mutShufflerParams.RUnlock() nbShardsNew := nbShards @@ -489,6 +461,14 @@ func (rhs *randHashShuffler) computeNewShards( return nbShardsNew } +func (rhs *randHashShuffler) metaHysteresis(chainParameters config.ChainParametersByEpochConfig) uint32 { + return uint32(chainParameters.Hysteresis * float32(chainParameters.MetachainMinNumNodes)) +} + +func (rhs *randHashShuffler) shardHysteresis(chainParameters config.ChainParametersByEpochConfig) uint32 { + return uint32(chainParameters.Hysteresis * float32(chainParameters.ShardMinNumNodes)) +} + // shuffleOutNodes shuffles the list of eligible validators in each shard and returns the map of shuffled out // validators func shuffleOutNodes( @@ -828,11 +808,11 @@ func sortKeys(nodes map[uint32][]Validator) []uint32 { return keys } -// updateShufflerConfig updates the shuffler config according to the current epoch. -func (rhs *randHashShuffler) updateShufflerConfig(epoch uint32) { +// UpdateShufflerConfig updates the shuffler config according to the current epoch. 
+func (rhs *randHashShuffler) UpdateShufflerConfig(epoch uint32, chainParameters config.ChainParametersByEpochConfig) { rhs.mutShufflerParams.Lock() defer rhs.mutShufflerParams.Unlock() - rhs.activeNodesConfig.NodesToShufflePerShard = rhs.nodesShard + rhs.activeNodesConfig.NodesToShufflePerShard = chainParameters.ShardMinNumNodes for _, maxNodesConfig := range rhs.availableNodesConfigs { if epoch >= maxNodesConfig.EpochEnable { rhs.activeNodesConfig = maxNodesConfig diff --git a/sharding/nodesCoordinator/hashValidatorShuffler_test.go b/sharding/nodesCoordinator/hashValidatorShuffler_test.go index 788ec3f9b59..089441bc287 100644 --- a/sharding/nodesCoordinator/hashValidatorShuffler_test.go +++ b/sharding/nodesCoordinator/hashValidatorShuffler_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/sharding/mock" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -28,6 +29,52 @@ const ( waitingPerShard = 30 ) +type testChainParametersCreator struct { + numNodesShards uint32 + numNodesMeta uint32 + hysteresis float32 + adaptivity bool +} + +func (t testChainParametersCreator) build() ChainParametersHandler { + return &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + EnableEpoch: 0, + Hysteresis: t.hysteresis, + ShardMinNumNodes: t.numNodesShards, + MetachainMinNumNodes: t.numNodesMeta, + ShardConsensusGroupSize: t.numNodesShards, + MetachainConsensusGroupSize: t.numNodesMeta, + Adaptivity: t.adaptivity, + } + }, + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + EnableEpoch: 0, + Hysteresis: t.hysteresis, + ShardMinNumNodes: t.numNodesShards, + MetachainMinNumNodes: t.numNodesMeta, + ShardConsensusGroupSize: t.numNodesShards, + MetachainConsensusGroupSize: t.numNodesMeta, + Adaptivity: t.adaptivity, + }, nil + }, + } +} + +func getTestChainParameters() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + EnableEpoch: 0, + Hysteresis: hysteresis, + ShardConsensusGroupSize: eligiblePerShard, + ShardMinNumNodes: eligiblePerShard, + MetachainConsensusGroupSize: eligiblePerShard, + MetachainMinNumNodes: eligiblePerShard, + Adaptivity: false, + } +} + func generateRandomByteArray(size int) []byte { r := make([]byte, size) _, _ = rand.Read(r) @@ -188,10 +235,6 @@ func testShuffledOut( func createHashShufflerInter() (*randHashShuffler, error) { shufflerArgs := &NodesShufflerArgs{ - NodesShard: eligiblePerShard, - NodesMeta: eligiblePerShard, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: true, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, EnableEpochs: config.EnableEpochs{ @@ -207,10 +250,6 @@ func createHashShufflerInter() (*randHashShuffler, error) { func createHashShufflerIntraShards() (*randHashShuffler, error) { shufflerArgs := &NodesShufflerArgs{ - NodesShard: eligiblePerShard, - NodesMeta: eligiblePerShard, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, EnableEpochs: config.EnableEpochs{ StakingV4Step2EnableEpoch: 443, @@ -589,8 +628,8 @@ func 
Test_removeValidatorsFromListRemoveFromLastMaxGreater(t *testing.T) { func Test_removeValidatorsFromListRandomValidatorsMaxSmaller(t *testing.T) { t.Parallel() - nbValidatotrsToRemove := 10 - maxToRemove := nbValidatotrsToRemove - 3 + nbValidatorsToRemove := 10 + maxToRemove := nbValidatorsToRemove - 3 validators := generateValidatorList(30) validatorsCopy := make([]Validator, len(validators)) validatorsToRemove := make([]Validator, 0) @@ -599,7 +638,7 @@ func Test_removeValidatorsFromListRandomValidatorsMaxSmaller(t *testing.T) { sort.Sort(validatorList(validators)) - validatorsToRemove = append(validatorsToRemove, validators[:nbValidatotrsToRemove]...) + validatorsToRemove = append(validatorsToRemove, validators[:nbValidatorsToRemove]...) v, removed := removeValidatorsFromList(validators, validatorsToRemove, maxToRemove) testRemoveValidators(t, validatorsCopy, validatorsToRemove, v, removed, maxToRemove) @@ -608,8 +647,8 @@ func Test_removeValidatorsFromListRandomValidatorsMaxSmaller(t *testing.T) { func Test_removeValidatorsFromListRandomValidatorsMaxGreater(t *testing.T) { t.Parallel() - nbValidatotrsToRemove := 10 - maxToRemove := nbValidatotrsToRemove + 3 + nbValidatorsToRemove := 10 + maxToRemove := nbValidatorsToRemove + 3 validators := generateValidatorList(30) validatorsCopy := make([]Validator, len(validators)) validatorsToRemove := make([]Validator, 0) @@ -618,13 +657,13 @@ func Test_removeValidatorsFromListRandomValidatorsMaxGreater(t *testing.T) { sort.Sort(validatorList(validators)) - validatorsToRemove = append(validatorsToRemove, validators[:nbValidatotrsToRemove]...) + validatorsToRemove = append(validatorsToRemove, validators[:nbValidatorsToRemove]...) v, removed := removeValidatorsFromList(validators, validatorsToRemove, maxToRemove) testRemoveValidators(t, validatorsCopy, validatorsToRemove, v, removed, maxToRemove) } -func Test_removeDupplicates_NoDupplicates(t *testing.T) { +func Test_removeDuplicates_NoDuplicates(t *testing.T) { t.Parallel() firstList := generateValidatorList(30) @@ -636,13 +675,13 @@ func Test_removeDupplicates_NoDupplicates(t *testing.T) { secondListCopy := make([]Validator, len(secondList)) copy(secondListCopy, secondList) - secondListAfterRemove := removeDupplicates(firstList, secondList) + secondListAfterRemove := removeDuplicates(firstList, secondList) assert.Equal(t, firstListCopy, firstList) assert.Equal(t, secondListCopy, secondListAfterRemove) } -func Test_removeDupplicates_SomeDupplicates(t *testing.T) { +func Test_removeDuplicates_SomeDuplicates(t *testing.T) { t.Parallel() firstList := generateValidatorList(30) @@ -656,14 +695,14 @@ func Test_removeDupplicates_SomeDupplicates(t *testing.T) { secondListCopy := make([]Validator, len(secondList)) copy(secondListCopy, secondList) - secondListAfterRemove := removeDupplicates(firstList, secondList) + secondListAfterRemove := removeDuplicates(firstList, secondList) assert.Equal(t, firstListCopy, firstList) assert.Equal(t, len(secondListCopy)-len(validatorsFromFirstList), len(secondListAfterRemove)) assert.Equal(t, secondListCopy[:20], secondListAfterRemove) } -func Test_removeDupplicates_AllDupplicates(t *testing.T) { +func Test_removeDuplicates_AllDuplicates(t *testing.T) { t.Parallel() firstList := generateValidatorList(30) @@ -675,7 +714,7 @@ func Test_removeDupplicates_AllDupplicates(t *testing.T) { secondListCopy := make([]Validator, len(secondList)) copy(secondListCopy, secondList) - secondListAfterRemove := removeDupplicates(firstList, secondList) + secondListAfterRemove := 
removeDuplicates(firstList, secondList) assert.Equal(t, firstListCopy, firstList) assert.Equal(t, len(secondListCopy)-len(firstListCopy), len(secondListAfterRemove)) @@ -1078,10 +1117,6 @@ func TestNewHashValidatorsShuffler(t *testing.T) { t.Parallel() shufflerArgs := &NodesShufflerArgs{ - NodesShard: eligiblePerShard, - NodesMeta: eligiblePerShard, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -1098,10 +1133,12 @@ func TestRandHashShuffler_computeNewShardsNotChanging(t *testing.T) { shuffler, err := createHashShufflerInter() require.Nil(t, err) - eligible := generateValidatorMap(int(shuffler.nodesShard), currentNbShards) + testChainParams := getTestChainParameters() + + eligible := generateValidatorMap(int(getTestChainParameters().ShardMinNumNodes), currentNbShards) nbShards := currentNbShards + 1 // account for meta - maxNodesNoSplit := (nbShards + 1) * (shuffler.nodesShard + shuffler.shardHysteresis) - nbWaitingPerShard := int(maxNodesNoSplit/nbShards - shuffler.nodesShard) + maxNodesNoSplit := (nbShards + 1) * (testChainParams.ShardMinNumNodes + shuffler.shardHysteresis(testChainParams)) + nbWaitingPerShard := int(maxNodesNoSplit/nbShards - testChainParams.ShardMinNumNodes) waiting := generateValidatorMap(nbWaitingPerShard, currentNbShards) newNodes := generateValidatorList(0) leavingUnstake := generateValidatorList(0) @@ -1110,7 +1147,7 @@ func TestRandHashShuffler_computeNewShardsNotChanging(t *testing.T) { numNewNodes := len(newNodes) numLeaving := len(leavingUnstake) + len(leavingRating) - newNbShards := shuffler.computeNewShards(eligible, waiting, numNewNodes, numLeaving, currentNbShards) + newNbShards := shuffler.computeNewShards(testChainParams, eligible, waiting, numNewNodes, numLeaving, currentNbShards) assert.Equal(t, currentNbShards, newNbShards) } @@ -1121,10 +1158,11 @@ func TestRandHashShuffler_computeNewShardsWithSplit(t *testing.T) { shuffler, err := createHashShufflerInter() require.Nil(t, err) - eligible := generateValidatorMap(int(shuffler.nodesShard), currentNbShards) + testChainParams := getTestChainParameters() + eligible := generateValidatorMap(int(testChainParams.ShardMinNumNodes), currentNbShards) nbShards := currentNbShards + 1 // account for meta - maxNodesNoSplit := (nbShards + 1) * (shuffler.nodesShard + shuffler.shardHysteresis) - nbWaitingPerShard := int(maxNodesNoSplit/nbShards-shuffler.nodesShard) + 1 + maxNodesNoSplit := (nbShards + 1) * (testChainParams.ShardMinNumNodes + shuffler.shardHysteresis(testChainParams)) + nbWaitingPerShard := int(maxNodesNoSplit/nbShards-testChainParams.ShardMinNumNodes) + 1 waiting := generateValidatorMap(nbWaitingPerShard, currentNbShards) newNodes := generateValidatorList(0) leavingUnstake := generateValidatorList(0) @@ -1133,7 +1171,7 @@ func TestRandHashShuffler_computeNewShardsWithSplit(t *testing.T) { numNewNodes := len(newNodes) numLeaving := len(leavingUnstake) + len(leavingRating) - newNbShards := shuffler.computeNewShards(eligible, waiting, numNewNodes, numLeaving, currentNbShards) + newNbShards := shuffler.computeNewShards(testChainParams, eligible, waiting, numNewNodes, numLeaving, currentNbShards) assert.Equal(t, currentNbShards+1, newNbShards) } @@ -1144,7 +1182,7 @@ func TestRandHashShuffler_computeNewShardsWithMerge(t *testing.T) { shuffler, err := createHashShufflerInter() require.Nil(t, err) - eligible := generateValidatorMap(int(shuffler.nodesShard), currentNbShards) + 
eligible := generateValidatorMap(int(getTestChainParameters().ShardMinNumNodes), currentNbShards) nbWaitingPerShard := 0 waiting := generateValidatorMap(nbWaitingPerShard, currentNbShards) newNodes := generateValidatorList(0) @@ -1154,47 +1192,17 @@ func TestRandHashShuffler_computeNewShardsWithMerge(t *testing.T) { numNewNodes := len(newNodes) numLeaving := len(leavingUnstake) + len(leavingRating) - newNbShards := shuffler.computeNewShards(eligible, waiting, numNewNodes, numLeaving, currentNbShards) + newNbShards := shuffler.computeNewShards(getTestChainParameters(), eligible, waiting, numNewNodes, numLeaving, currentNbShards) assert.Equal(t, currentNbShards-1, newNbShards) } -func TestRandHashShuffler_UpdateParams(t *testing.T) { - t.Parallel() - - shuffler, err := createHashShufflerInter() - require.Nil(t, err) - - shuffler2 := &randHashShuffler{ - nodesShard: 200, - nodesMeta: 200, - shardHysteresis: 0, - metaHysteresis: 0, - adaptivity: true, - shuffleBetweenShards: true, - validatorDistributor: &CrossShardValidatorDistributor{}, - availableNodesConfigs: nil, - stakingV4Step2EnableEpoch: 443, - stakingV4Step3EnableEpoch: 444, - enableEpochsHandler: &mock.EnableEpochsHandlerMock{}, - } - - shuffler.UpdateParams( - shuffler2.nodesShard, - shuffler2.nodesMeta, - 0, - shuffler2.adaptivity, - ) - - assert.Equal(t, shuffler2, shuffler) -} - func TestRandHashShuffler_UpdateNodeListsNoReSharding(t *testing.T) { t.Parallel() shuffler, err := createHashShufflerInter() require.Nil(t, err) - eligiblePerShard := int(shuffler.nodesShard) + eligiblePerShard := int(getTestChainParameters().ShardMinNumNodes) waitingPerShard := 30 nbShards := uint32(3) randomness := generateRandomByteArray(32) @@ -1215,6 +1223,12 @@ func TestRandHashShuffler_UpdateNodeListsNoReSharding(t *testing.T) { Rand: randomness, NbShards: nbShards, } + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() resUpdateNodeList, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) @@ -1234,10 +1248,6 @@ func TestRandHashShuffler_UpdateNodeListsWithUnstakeLeavingRemovesFromEligible(t eligibleMeta := 10 shufflerArgs := &NodesShufflerArgs{ - NodesShard: uint32(eligiblePerShard), - NodesMeta: uint32(eligibleMeta), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -1256,6 +1266,13 @@ func TestRandHashShuffler_UpdateNodeListsWithUnstakeLeavingRemovesFromEligible(t args.Eligible[core.MetachainShardId][1], } + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligibleMeta), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() + resUpdateNodeList, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) @@ -1291,10 +1308,6 @@ func testUpdateNodesAndCheckNumLeaving(t *testing.T, beforeFix bool) { numNodesToShuffle := 80 shufflerArgs := &NodesShufflerArgs{ - NodesShard: uint32(eligiblePerShard), - NodesMeta: uint32(eligibleMeta), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: []config.MaxNodesChangeConfig{ { @@ -1315,6 +1328,13 @@ func testUpdateNodesAndCheckNumLeaving(t *testing.T, beforeFix bool) { args.UnStakeLeaving = append(args.UnStakeLeaving, args.Waiting[0][i]) } 
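// Illustrative sketch, not part of the diff: the hysteresis values that the
// removed NodesShard/NodesMeta/Hysteresis fields used to pre-compute are now
// derived on demand from the chain parameters (see the metaHysteresis and
// shardHysteresis helpers above). Written as if inside
// hashValidatorShuffler_test.go, which already imports the config package;
// the numbers are placeholders.
func exampleHysteresisFromChainParams() (uint32, uint32) {
	params := config.ChainParametersByEpochConfig{
		Hysteresis:           0.2, // placeholder
		ShardMinNumNodes:     400, // placeholder
		MetachainMinNumNodes: 400, // placeholder
	}
	shardHyst := uint32(params.Hysteresis * float32(params.ShardMinNumNodes))    // 80
	metaHyst := uint32(params.Hysteresis * float32(params.MetachainMinNumNodes)) // 80
	// a shard may grow to ShardMinNumNodes+shardHyst (480 here) before a split
	// becomes possible
	return shardHyst, metaHyst
}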
+ args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligibleMeta), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() + resUpdateNodeList, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) @@ -1342,10 +1362,6 @@ func TestRandHashShuffler_UpdateNodeListsAndCheckWaitingList(t *testing.T) { numNodesToShuffle := 80 shufflerArgs := &NodesShufflerArgs{ - NodesShard: uint32(eligiblePerShard), - NodesMeta: uint32(eligibleMeta), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: []config.MaxNodesChangeConfig{ { @@ -1372,6 +1388,13 @@ func TestRandHashShuffler_UpdateNodeListsAndCheckWaitingList(t *testing.T) { args.UnStakeLeaving = append(args.UnStakeLeaving, args.Waiting[0][i]) } + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligibleMeta), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() + resUpdateNodeList, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) @@ -1400,10 +1423,6 @@ func TestRandHashShuffler_UpdateNodeListsWithUnstakeLeavingRemovesFromWaiting(t eligibleMeta := 10 shufflerArgs := &NodesShufflerArgs{ - NodesShard: uint32(eligiblePerShard), - NodesMeta: uint32(eligibleMeta), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -1422,6 +1441,13 @@ func TestRandHashShuffler_UpdateNodeListsWithUnstakeLeavingRemovesFromWaiting(t args.Waiting[core.MetachainShardId][1], } + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligibleMeta), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() + resUpdateNodeList, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) @@ -1443,10 +1469,6 @@ func TestRandHashShuffler_UpdateNodeListsWithNonExistentUnstakeLeavingDoesNotRem t.Parallel() shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -1454,7 +1476,7 @@ func TestRandHashShuffler_UpdateNodeListsWithNonExistentUnstakeLeavingDoesNotRem shuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - eligiblePerShard := int(shuffler.nodesShard) + eligiblePerShard := 10 waitingPerShard := 2 nbShards := uint32(0) @@ -1469,6 +1491,12 @@ func TestRandHashShuffler_UpdateNodeListsWithNonExistentUnstakeLeavingDoesNotRem }, } + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(10), + numNodesMeta: uint32(10), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() resUpdateNodeList, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) @@ -1495,10 +1523,6 @@ func TestRandHashShuffler_UpdateNodeListsWithRangeOnMaps(t *testing.T) { for _, shuffle := range shuffleBetweenShards { shufflerArgs := &NodesShufflerArgs{ - NodesShard: uint32(eligiblePerShard), - NodesMeta: uint32(eligiblePerShard), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffle, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -1534,6 +1558,13 @@ func 
TestRandHashShuffler_UpdateNodeListsWithRangeOnMaps(t *testing.T) { args.UnStakeLeaving = leavingValidators + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() + resUpdateNodeListInitial, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) @@ -1573,6 +1604,12 @@ func TestRandHashShuffler_UpdateNodeListsNoReShardingIntraShardShuffling(t *test Rand: randomness, NbShards: nbShards, } + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() resUpdateNodeList, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) @@ -1937,6 +1974,12 @@ func TestRandHashShuffler_UpdateNodeLists_WithUnstakeLeaving(t *testing.T) { Rand: generateRandomByteArray(32), NbShards: nbShards, } + arg.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() result, err := shuffler.UpdateNodeLists(arg) require.Nil(t, err) @@ -1986,6 +2029,12 @@ func TestRandHashShuffler_UpdateNodeLists_WithUnstakeLeaving_EnoughRemaining(t * Rand: generateRandomByteArray(32), NbShards: nbShards, } + arg.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() result, err := shuffler.UpdateNodeLists(arg) assert.NotNil(t, result) @@ -2018,6 +2067,12 @@ func TestRandHashShuffler_UpdateNodeLists_WithUnstakeLeaving_NotEnoughRemaining( Rand: generateRandomByteArray(32), NbShards: nbShards, } + arg.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() _, err = shuffler.UpdateNodeLists(arg) assert.True(t, errors.Is(err, ErrSmallShardEligibleListSize)) @@ -2039,6 +2094,12 @@ func TestRandHashShuffler_UpdateNodeLists_WithUnstakeLeaving_NotEnoughRemaining( Rand: generateRandomByteArray(32), NbShards: uint32(len(eligibleMap)), } + arg.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() _, err = shuffler.UpdateNodeLists(arg) assert.True(t, errors.Is(err, ErrSmallShardEligibleListSize)) @@ -2086,6 +2147,12 @@ func TestRandHashShuffler_UpdateNodeLists_WithAdditionalLeaving(t *testing.T) { Rand: generateRandomByteArray(32), NbShards: nbShards, } + arg.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() result, err := shuffler.UpdateNodeLists(arg) require.Nil(t, err) @@ -2110,7 +2177,7 @@ func TestRandHashShuffler_UpdateNodeLists_WithAdditionalLeaving(t *testing.T) { } } -func TestRandHashShuffler_UpdateNodeLists_WithUnstakeAndAdditionalLeaving_NoDupplicates(t *testing.T) { +func TestRandHashShuffler_UpdateNodeLists_WithUnstakeAndAdditionalLeaving_NoDuplicates(t *testing.T) { t.Parallel() 
eligiblePerShard := 100 @@ -2162,6 +2229,12 @@ func TestRandHashShuffler_UpdateNodeLists_WithUnstakeAndAdditionalLeaving_NoDupp Rand: generateRandomByteArray(32), NbShards: nbShards, } + arg.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() result, err := shuffler.UpdateNodeLists(arg) require.Nil(t, err) @@ -2185,7 +2258,7 @@ func TestRandHashShuffler_UpdateNodeLists_WithUnstakeAndAdditionalLeaving_NoDupp ) } } -func TestRandHashShuffler_UpdateNodeLists_WithAdditionalLeaving_WithDupplicates(t *testing.T) { +func TestRandHashShuffler_UpdateNodeLists_WithAdditionalLeaving_WithDuplicates(t *testing.T) { t.Parallel() eligiblePerShard := 100 @@ -2244,6 +2317,12 @@ func TestRandHashShuffler_UpdateNodeLists_WithAdditionalLeaving_WithDupplicates( Rand: generateRandomByteArray(32), NbShards: nbShards, } + arg.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() result, err := shuffler.UpdateNodeLists(arg) require.Nil(t, err) @@ -2316,10 +2395,6 @@ func TestRandHashShuffler_UpdateNodeLists_All(t *testing.T) { unstakeLeavingList, additionalLeavingList := prepareListsFromMaps(unstakeLeaving, additionalLeaving) shufflerArgs := &NodesShufflerArgs{ - NodesShard: uint32(eligiblePerShard), - NodesMeta: uint32(eligiblePerShard), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, EnableEpochs: config.EnableEpochs{ StakingV4Step2EnableEpoch: 443, @@ -2339,6 +2414,12 @@ func TestRandHashShuffler_UpdateNodeLists_All(t *testing.T) { Rand: generateRandomByteArray(32), NbShards: nbShards, } + arg.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() result, err := shuffler.UpdateNodeLists(arg) require.Nil(t, err) @@ -2423,10 +2504,6 @@ func TestRandHashShuffler_UpdateNodeLists_WithNewNodes_NoWaiting(t *testing.T) { } shufflerArgs := &NodesShufflerArgs{ - NodesShard: uint32(eligiblePerShard), - NodesMeta: uint32(eligiblePerShard), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -2435,6 +2512,12 @@ func TestRandHashShuffler_UpdateNodeLists_WithNewNodes_NoWaiting(t *testing.T) { shuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() resUpdateNodeList, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) @@ -2485,10 +2568,6 @@ func TestRandHashShuffler_UpdateNodeLists_WithNewNodes_NilOrEmptyWaiting(t *test } shufflerArgs := &NodesShufflerArgs{ - NodesShard: uint32(eligiblePerShard), - NodesMeta: uint32(eligiblePerShard), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -2497,6 +2576,13 @@ func TestRandHashShuffler_UpdateNodeLists_WithNewNodes_NilOrEmptyWaiting(t *test 
shuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() + resUpdateNodeList, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) require.Equal(t, int(nbShards+1), len(resUpdateNodeList.Waiting)) @@ -2511,6 +2597,13 @@ func TestRandHashShuffler_UpdateNodeLists_WithNewNodes_NilOrEmptyWaiting(t *test NbShards: nbShards, } + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() + resUpdateNodeList, err = shuffler.UpdateNodeLists(args) require.Nil(t, err) require.Equal(t, int(nbShards+1), len(resUpdateNodeList.Waiting)) @@ -2541,6 +2634,12 @@ func TestRandHashShuffler_UpdateNodeLists_WithNewNodes_WithWaiting(t *testing.T) Rand: randomness, NbShards: uint32(nbShards), } + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() shuffler, err := createHashShufflerIntraShards() require.Nil(t, err) @@ -2586,6 +2685,7 @@ func TestRandHashShuffler_UpdateNodeLists_WithStakingV4(t *testing.T) { Auction: auctionList, NbShards: nbShards, Epoch: stakingV4Epoch, + ChainParameters: getTestChainParameters(), } shuffler, _ := createHashShufflerIntraShards() @@ -2668,10 +2768,6 @@ func TestRandHashShuffler_UpdateNodeLists_WithNewNodes_WithWaiting_WithLeaving(t } shufflerArgs := &NodesShufflerArgs{ - NodesShard: uint32(numEligiblePerShard), - NodesMeta: uint32(numEligiblePerShard), - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, EnableEpochs: config.EnableEpochs{ StakingV4Step2EnableEpoch: 443, @@ -2682,6 +2778,12 @@ func TestRandHashShuffler_UpdateNodeLists_WithNewNodes_WithWaiting_WithLeaving(t shuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) + args.ChainParameters = testChainParametersCreator{ + numNodesShards: uint32(numEligiblePerShard), + numNodesMeta: uint32(numEligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() resUpdateNodeList, err := shuffler.UpdateNodeLists(args) require.Nil(t, err) @@ -2723,7 +2825,7 @@ func verifyResultsIntraShardShuffling( initialNumWaiting := len(waiting) numToRemove := initialNumWaiting - additionalLeaving = removeDupplicates(unstakeLeaving, additionalLeaving) + additionalLeaving = removeDuplicates(unstakeLeaving, additionalLeaving) computedNewWaiting, removedFromWaiting := removeValidatorsFromList(waiting, unstakeLeaving, numToRemove) removedNodes = append(removedNodes, removedFromWaiting...) 
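The computeNewShards tests above exercise the split/merge thresholds, which are now a pure function of the epoch's chain parameters. A small worked example under assumed numbers (3 shards, 400 minimum nodes per shard and for the metachain, 0.2 hysteresis); the function name and values are illustrative only, and the snippet assumes the config import already present in this test file:

func exampleSplitMergeThresholds() (uint32, uint32) {
	params := config.ChainParametersByEpochConfig{
		Hysteresis:           0.2, // assumed
		ShardMinNumNodes:     400, // assumed
		MetachainMinNumNodes: 400, // assumed
	}
	nbShards := uint32(3) // assumed current number of shards (without meta)

	shardHyst := uint32(params.Hysteresis * float32(params.ShardMinNumNodes))    // 80
	metaHyst := uint32(params.Hysteresis * float32(params.MetachainMinNumNodes)) // 80
	maxNodesShard := params.ShardMinNumNodes + shardHyst                         // 480
	maxNodesMeta := params.MetachainMinNumNodes + metaHyst                       // 480

	// same formulas as computeNewShards: above nodesForSplit a new shard can be
	// created, below nodesForMerge a shard merge is considered
	nodesForSplit := (nbShards+1)*maxNodesShard + maxNodesMeta                      // 4*480 + 480 = 2400
	nodesForMerge := nbShards*params.ShardMinNumNodes + params.MetachainMinNumNodes // 3*400 + 400 = 1600
	return nodesForSplit, nodesForMerge
}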
@@ -2899,10 +3001,6 @@ func TestRandHashShuffler_sortConfigs(t *testing.T) { require.NotEqual(t, orderedConfigs, shuffledConfigs) shufflerArgs := &NodesShufflerArgs{ - NodesShard: eligiblePerShard, - NodesMeta: eligiblePerShard, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: shuffledConfigs, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -2918,10 +3016,6 @@ func TestRandHashShuffler_UpdateShufflerConfig(t *testing.T) { orderedConfigs := getDummyShufflerConfigs() shufflerArgs := &NodesShufflerArgs{ - NodesShard: eligiblePerShard, - NodesMeta: eligiblePerShard, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: orderedConfigs, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -2934,7 +3028,13 @@ func TestRandHashShuffler_UpdateShufflerConfig(t *testing.T) { if epoch == orderedConfigs[(i+1)%len(orderedConfigs)].EpochEnable { i++ } - shuffler.updateShufflerConfig(epoch) + chainParams := testChainParametersCreator{ + numNodesShards: uint32(eligiblePerShard), + numNodesMeta: uint32(eligiblePerShard), + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build().CurrentChainParameters() + shuffler.UpdateShufflerConfig(epoch, chainParams) require.Equal(t, orderedConfigs[i], shuffler.activeNodesConfig) } } diff --git a/sharding/nodesCoordinator/indexHashedNodesCoordinator.go b/sharding/nodesCoordinator/indexHashedNodesCoordinator.go index 4898f018010..9d1fe16551a 100644 --- a/sharding/nodesCoordinator/indexHashedNodesCoordinator.go +++ b/sharding/nodesCoordinator/indexHashedNodesCoordinator.go @@ -30,6 +30,7 @@ const ( keyFormat = "%s_%v_%v_%v" defaultSelectionChances = uint32(1) minEpochsToWait = uint32(1) + leaderSelectionSize = 1 ) // TODO: move this to config parameters @@ -40,6 +41,12 @@ type validatorWithShardID struct { shardID uint32 } +// savedConsensusGroup holds the leader and consensus group for a specific selection +type savedConsensusGroup struct { + leader Validator + consensusGroup []Validator +} + type validatorList []Validator // Len will return the length of the validatorList @@ -75,8 +82,7 @@ type epochNodesConfig struct { type indexHashedNodesCoordinator struct { shardIDAsObserver uint32 currentEpoch uint32 - shardConsensusGroupSize int - metaConsensusGroupSize int + chainParametersHandler ChainParametersHandler numTotalEligible uint64 selfPubKey []byte savedStateKey []byte @@ -140,8 +146,7 @@ func NewIndexHashedNodesCoordinator(arguments ArgNodesCoordinator) (*indexHashed nodesConfig: nodesConfig, currentEpoch: arguments.Epoch, savedStateKey: savedKey, - shardConsensusGroupSize: arguments.ShardConsensusGroupSize, - metaConsensusGroupSize: arguments.MetaConsensusGroupSize, + chainParametersHandler: arguments.ChainParametersHandler, consensusGroupCacher: arguments.ConsensusGroupCache, shardIDAsObserver: arguments.ShardIDAsObserver, shuffledOutHandler: arguments.ShuffledOutHandler, @@ -195,8 +200,8 @@ func NewIndexHashedNodesCoordinator(arguments ArgNodesCoordinator) (*indexHashed } func checkArguments(arguments ArgNodesCoordinator) error { - if arguments.ShardConsensusGroupSize < 1 || arguments.MetaConsensusGroupSize < 1 { - return ErrInvalidConsensusGroupSize + if check.IfNil(arguments.ChainParametersHandler) { + return ErrNilChainParametersHandler } if arguments.NbShards < 1 { return ErrInvalidNumberOfShards @@ -278,21 +283,25 @@ func (ihnc *indexHashedNodesCoordinator) setNodesPerShards( return ErrNilInputNodesMap } + 
currentChainParameters, err := ihnc.chainParametersHandler.ChainParametersForEpoch(epoch) + if err != nil { + return err + } + nodesList := eligible[core.MetachainShardId] - if len(nodesList) < ihnc.metaConsensusGroupSize { + if len(nodesList) < int(currentChainParameters.MetachainConsensusGroupSize) { return ErrSmallMetachainEligibleListSize } numTotalEligible := uint64(len(nodesList)) for shardId := uint32(0); shardId < uint32(len(eligible)-1); shardId++ { nbNodesShard := len(eligible[shardId]) - if nbNodesShard < ihnc.shardConsensusGroupSize { + if nbNodesShard < int(currentChainParameters.ShardConsensusGroupSize) { return ErrSmallShardEligibleListSize } numTotalEligible += uint64(nbNodesShard) } - var err error var isCurrentNodeValidator bool // nbShards holds number of shards without meta nodesConfig.nbShards = uint32(len(eligible) - 1) @@ -344,7 +353,7 @@ func (ihnc *indexHashedNodesCoordinator) ComputeConsensusGroup( round uint64, shardID uint32, epoch uint32, -) (validatorsGroup []Validator, err error) { +) (leader Validator, validatorsGroup []Validator, err error) { var selector RandomSelector var eligibleList []Validator @@ -355,7 +364,7 @@ func (ihnc *indexHashedNodesCoordinator) ComputeConsensusGroup( "round", round) if len(randomness) == 0 { - return nil, ErrNilRandomness + return nil, nil, ErrNilRandomness } ihnc.mutNodesConfig.RLock() @@ -364,7 +373,7 @@ func (ihnc *indexHashedNodesCoordinator) ComputeConsensusGroup( if shardID >= nodesConfig.nbShards && shardID != core.MetachainShardId { log.Warn("shardID is not ok", "shardID", shardID, "nbShards", nodesConfig.nbShards) ihnc.mutNodesConfig.RUnlock() - return nil, ErrInvalidShardId + return nil, nil, ErrInvalidShardId } selector = nodesConfig.selectors[shardID] eligibleList = nodesConfig.eligibleMap[shardID] @@ -372,16 +381,16 @@ func (ihnc *indexHashedNodesCoordinator) ComputeConsensusGroup( ihnc.mutNodesConfig.RUnlock() if !ok { - return nil, fmt.Errorf("%w epoch=%v", ErrEpochNodesConfigDoesNotExist, epoch) + return nil, nil, fmt.Errorf("%w epoch=%v", ErrEpochNodesConfigDoesNotExist, epoch) } key := []byte(fmt.Sprintf(keyFormat, string(randomness), round, shardID, epoch)) - validators := ihnc.searchConsensusForKey(key) - if validators != nil { - return validators, nil + savedCG := ihnc.searchConsensusForKey(key) + if savedCG != nil { + return savedCG.leader, savedCG.consensusGroup, nil } - consensusSize := ihnc.ConsensusGroupSize(shardID) + consensusSize := ihnc.ConsensusGroupSizeForShardAndEpoch(shardID, epoch) randomness = []byte(fmt.Sprintf("%d-%s", round, randomness)) log.Debug("computeValidatorsGroup", @@ -392,27 +401,59 @@ func (ihnc *indexHashedNodesCoordinator) ComputeConsensusGroup( "round", round, "shardID", shardID) - tempList, err := selectValidators(selector, randomness, uint32(consensusSize), eligibleList) + leader, validatorsGroup, err = ihnc.selectLeaderAndConsensusGroup(selector, randomness, eligibleList, consensusSize, epoch) if err != nil { - return nil, err + return nil, nil, err } - size := 0 - for _, v := range tempList { - size += v.Size() + ihnc.cacheConsensusGroup(key, validatorsGroup, leader) + + return leader, validatorsGroup, nil +} + +func (ihnc *indexHashedNodesCoordinator) cacheConsensusGroup(key []byte, consensusGroup []Validator, leader Validator) { + size := leader.Size() * len(consensusGroup) + savedCG := &savedConsensusGroup{ + leader: leader, + consensusGroup: consensusGroup, } + ihnc.consensusGroupCacher.Put(key, savedCG, size) +} - ihnc.consensusGroupCacher.Put(key, tempList, size) 
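// Illustrative usage note, not part of the diff: ComputeConsensusGroup now
// returns the leader explicitly, so callers no longer take the first element of
// the returned group as the leader. The variable names below are placeholders:
//
//	leader, group, err := ihnc.ComputeConsensusGroup(randomness, round, shardID, epoch)
//	if err != nil {
//		return err
//	}
//	log.Debug("consensus group computed", "leader", leader.PubKey(), "group size", len(group))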
+func (ihnc *indexHashedNodesCoordinator) selectLeaderAndConsensusGroup( + selector RandomSelector, + randomness []byte, + eligibleList []Validator, + consensusSize int, + epoch uint32, +) (Validator, []Validator, error) { + leaderPositionInSelection := 0 + if !ihnc.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + tempList, err := selectValidators(selector, randomness, uint32(consensusSize), eligibleList) + if err != nil { + return nil, nil, err + } + + if len(tempList) == 0 { + return nil, nil, ErrEmptyValidatorsList + } + + return tempList[leaderPositionInSelection], tempList, nil + } - return tempList, nil + selectedValidators, err := selectValidators(selector, randomness, leaderSelectionSize, eligibleList) + if err != nil { + return nil, nil, err + } + return selectedValidators[leaderPositionInSelection], eligibleList, nil } -func (ihnc *indexHashedNodesCoordinator) searchConsensusForKey(key []byte) []Validator { +func (ihnc *indexHashedNodesCoordinator) searchConsensusForKey(key []byte) *savedConsensusGroup { value, ok := ihnc.consensusGroupCacher.Get(key) if ok { - consensusGroup, typeOk := value.([]Validator) + savedCG, typeOk := value.(*savedConsensusGroup) if typeOk { - return consensusGroup + return savedCG } } return nil @@ -440,10 +481,10 @@ func (ihnc *indexHashedNodesCoordinator) GetConsensusValidatorsPublicKeys( round uint64, shardID uint32, epoch uint32, -) ([]string, error) { - consensusNodes, err := ihnc.ComputeConsensusGroup(randomness, round, shardID, epoch) +) (string, []string, error) { + leader, consensusNodes, err := ihnc.ComputeConsensusGroup(randomness, round, shardID, epoch) if err != nil { - return nil, err + return "", nil, err } pubKeys := make([]string, 0) @@ -452,7 +493,29 @@ func (ihnc *indexHashedNodesCoordinator) GetConsensusValidatorsPublicKeys( pubKeys = append(pubKeys, string(v.PubKey())) } - return pubKeys, nil + return string(leader.PubKey()), pubKeys, nil +} + +// GetAllEligibleValidatorsPublicKeysForShard will return all validators public keys for the provided shard +func (ihnc *indexHashedNodesCoordinator) GetAllEligibleValidatorsPublicKeysForShard(epoch uint32, shardID uint32) ([]string, error) { + ihnc.mutNodesConfig.RLock() + nodesConfig, ok := ihnc.nodesConfig[epoch] + ihnc.mutNodesConfig.RUnlock() + + if !ok { + return nil, fmt.Errorf("%w epoch=%v", ErrEpochNodesConfigDoesNotExist, epoch) + } + + nodesConfig.mutNodesMaps.RLock() + defer nodesConfig.mutNodesMaps.RUnlock() + + shardEligible := nodesConfig.eligibleMap[shardID] + validatorsPubKeys := make([]string, 0, len(shardEligible)) + for i := 0; i < len(shardEligible); i++ { + validatorsPubKeys = append(validatorsPubKeys, string(shardEligible[i].PubKey())) + } + + return validatorsPubKeys, nil } // GetAllEligibleValidatorsPublicKeys will return all validators public keys for all shards @@ -617,6 +680,19 @@ func (ihnc *indexHashedNodesCoordinator) GetValidatorsIndexes( return signersIndexes, nil } +// GetCachedEpochs returns all epochs cached +func (ihnc *indexHashedNodesCoordinator) GetCachedEpochs() map[uint32]struct{} { + cachedEpochs := make(map[uint32]struct{}, nodesCoordinatorStoredEpochs) + + ihnc.mutNodesConfig.RLock() + for epoch := range ihnc.nodesConfig { + cachedEpochs[epoch] = struct{}{} + } + ihnc.mutNodesConfig.RUnlock() + + return cachedEpochs +} + // EpochStartPrepare is called when an epoch start event is observed, but not yet confirmed/committed. 
// Some components may need to do some initialisation on this event func (ihnc *indexHashedNodesCoordinator) EpochStartPrepare(metaHdr data.HeaderHandler, body data.BodyHandler) { @@ -663,7 +739,14 @@ func (ihnc *indexHashedNodesCoordinator) EpochStartPrepare(metaHdr data.HeaderHa unStakeLeavingList := ihnc.createSortedListFromMap(newNodesConfig.leavingMap) additionalLeavingList := ihnc.createSortedListFromMap(additionalLeavingMap) + chainParamsForEpoch, err := ihnc.chainParametersHandler.ChainParametersForEpoch(newEpoch) + if err != nil { + log.Warn("indexHashedNodesCoordinator.EpochStartPrepare: could not compute chain params for epoch. "+ + "Will use the current chain parameters", "epoch", newEpoch, "error", err) + chainParamsForEpoch = ihnc.chainParametersHandler.CurrentChainParameters() + } shufflerArgs := ArgsUpdateNodes{ + ChainParameters: chainParamsForEpoch, Eligible: newNodesConfig.eligibleMap, Waiting: newNodesConfig.waitingMap, NewNodes: newNodesConfig.newList, @@ -1134,15 +1217,12 @@ func (ihnc *indexHashedNodesCoordinator) computeShardForSelfPublicKey(nodesConfi return selfShard, false } -// ConsensusGroupSize returns the consensus group size for a specific shard -func (ihnc *indexHashedNodesCoordinator) ConsensusGroupSize( +// ConsensusGroupSizeForShardAndEpoch returns the consensus group size for a specific shard in a given epoch +func (ihnc *indexHashedNodesCoordinator) ConsensusGroupSizeForShardAndEpoch( shardID uint32, + epoch uint32, ) int { - if shardID == core.MetachainShardId { - return ihnc.metaConsensusGroupSize - } - - return ihnc.shardConsensusGroupSize + return common.ConsensusGroupSizeForShardAndEpoch(log, ihnc.chainParametersHandler, shardID, epoch) } // GetNumTotalEligible returns the number of total eligible accross all shards from current setup @@ -1245,7 +1325,7 @@ func computeActuallyLeaving( func selectValidators( selector RandomSelector, randomness []byte, - consensusSize uint32, + selectionSize uint32, eligibleList []Validator, ) ([]Validator, error) { if check.IfNil(selector) { @@ -1256,19 +1336,19 @@ func selectValidators( } // todo: checks for indexes - selectedIndexes, err := selector.Select(randomness, consensusSize) + selectedIndexes, err := selector.Select(randomness, selectionSize) if err != nil { return nil, err } - consensusGroup := make([]Validator, consensusSize) - for i := range consensusGroup { - consensusGroup[i] = eligibleList[selectedIndexes[i]] + selectedValidators := make([]Validator, selectionSize) + for i := range selectedValidators { + selectedValidators[i] = eligibleList[selectedIndexes[i]] } - displayValidatorsForRandomness(consensusGroup, randomness) + displayValidatorsForRandomness(selectedValidators, randomness) - return consensusGroup, nil + return selectedValidators, nil } // createValidatorInfoFromBody unmarshalls body data to create validator info diff --git a/sharding/nodesCoordinator/indexHashedNodesCoordinatorLite.go b/sharding/nodesCoordinator/indexHashedNodesCoordinatorLite.go index b5b87781a73..46564c6486c 100644 --- a/sharding/nodesCoordinator/indexHashedNodesCoordinatorLite.go +++ b/sharding/nodesCoordinator/indexHashedNodesCoordinatorLite.go @@ -19,7 +19,13 @@ func (ihnc *indexHashedNodesCoordinator) SetNodesConfigFromValidatorsInfo(epoch unStakeLeavingList := ihnc.createSortedListFromMap(newNodesConfig.leavingMap) additionalLeavingList := ihnc.createSortedListFromMap(additionalLeavingMap) + chainParameters, err := ihnc.chainParametersHandler.ChainParametersForEpoch(epoch) + if err != nil { + return err + } + 
shufflerArgs := ArgsUpdateNodes{ + ChainParameters: chainParameters, Eligible: newNodesConfig.eligibleMap, Waiting: newNodesConfig.waitingMap, NewNodes: newNodesConfig.newList, diff --git a/sharding/nodesCoordinator/indexHashedNodesCoordinatorLite_test.go b/sharding/nodesCoordinator/indexHashedNodesCoordinatorLite_test.go index e880d564ca2..b54fed7860d 100644 --- a/sharding/nodesCoordinator/indexHashedNodesCoordinatorLite_test.go +++ b/sharding/nodesCoordinator/indexHashedNodesCoordinatorLite_test.go @@ -84,10 +84,13 @@ func TestIndexHashedNodesCoordinator_SetNodesConfigFromValidatorsInfo(t *testing t.Parallel() arguments := createArguments() - + arguments.ChainParametersHandler = testChainParametersCreator{ + numNodesShards: 3, + numNodesMeta: 3, + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build() shufflerArgs := &NodesShufflerArgs{ - NodesShard: 3, - NodesMeta: 3, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, } nodeShuffler, _ := NewHashValidatorsShuffler(shufflerArgs) @@ -108,10 +111,14 @@ func TestIndexHashedNodesCoordinator_SetNodesConfigFromValidatorsInfoMultipleEpo t.Parallel() arguments := createArguments() + arguments.ChainParametersHandler = testChainParametersCreator{ + numNodesShards: 3, + numNodesMeta: 3, + hysteresis: hysteresis, + adaptivity: adaptivity, + }.build() shufflerArgs := &NodesShufflerArgs{ - NodesShard: 3, - NodesMeta: 3, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, } nodeShuffler, _ := NewHashValidatorsShuffler(shufflerArgs) diff --git a/sharding/nodesCoordinator/indexHashedNodesCoordinatorWithRater_test.go b/sharding/nodesCoordinator/indexHashedNodesCoordinatorWithRater_test.go index a80006cceae..1154d93ae1a 100644 --- a/sharding/nodesCoordinator/indexHashedNodesCoordinatorWithRater_test.go +++ b/sharding/nodesCoordinator/indexHashedNodesCoordinatorWithRater_test.go @@ -17,8 +17,11 @@ import ( "github.com/stretchr/testify/require" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/sharding/mock" "github.com/multiversx/mx-chain-go/state" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" + testscommonConsensus "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" @@ -65,10 +68,6 @@ func TestIndexHashedGroupSelectorWithRater_OkValShouldWork(t *testing.T) { eligibleMap := createDummyNodesMap(3, 1, "waiting") waitingMap := make(map[uint32][]Validator) shufflerArgs := &NodesShufflerArgs{ - NodesShard: 3, - NodesMeta: 3, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -76,12 +75,18 @@ func TestIndexHashedGroupSelectorWithRater_OkValShouldWork(t *testing.T) { nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 2, - MetaConsensusGroupSize: 1, + ChainParametersHandler: 
&chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 2, + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -138,8 +143,9 @@ func TestIndexHashedGroupSelectorWithRater_ComputeValidatorsGroup1ValidatorShoul assert.Equal(t, false, chancesCalled) ihnc, _ := NewIndexHashedNodesCoordinatorWithRater(nc, rater) assert.Equal(t, true, chancesCalled) - list2, err := ihnc.ComputeConsensusGroup([]byte("randomness"), 0, 0, 0) + leader, list2, err := ihnc.ComputeConsensusGroup([]byte("randomness"), 0, 0, 0) + assert.Equal(t, list[0], leader) assert.Nil(t, err) assert.Equal(t, 1, len(list2)) } @@ -165,22 +171,24 @@ func BenchmarkIndexHashedGroupSelectorWithRater_ComputeValidatorsGroup63of400(b eligibleMap[core.MetachainShardId] = listMeta shufflerArgs := &NodesShufflerArgs{ - NodesShard: 400, - NodesMeta: 1, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, } nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(b, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusGroupSize, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -207,7 +215,7 @@ func BenchmarkIndexHashedGroupSelectorWithRater_ComputeValidatorsGroup63of400(b for i := 0; i < b.N; i++ { randomness := strconv.Itoa(0) - list2, _ := ihncRater.ComputeConsensusGroup([]byte(randomness), uint64(0), 0, 0) + _, list2, _ := ihncRater.ComputeConsensusGroup([]byte(randomness), uint64(0), 0, 0) assert.Equal(b, consensusGroupSize, len(list2)) } @@ -241,10 +249,6 @@ func Test_ComputeValidatorsGroup63of400(t *testing.T) { eligibleMap[0] = list eligibleMap[core.MetachainShardId] = listMeta shufflerArgs := &NodesShufflerArgs{ - NodesShard: shardSize, - NodesMeta: 1, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -252,12 +256,18 @@ func Test_ComputeValidatorsGroup63of400(t *testing.T) { nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusGroupSize, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + } + }, + }, Hasher: &hashingMocks.HasherMock{}, Shuffler: 
nodeShuffler, EpochStartNotifier: epochStartSubscriber, @@ -279,8 +289,8 @@ func Test_ComputeValidatorsGroup63of400(t *testing.T) { hasher := sha256.NewSha256() for i := uint64(0); i < numRounds; i++ { randomness := hasher.Compute(fmt.Sprintf("%v%v", i, time.Millisecond)) - consensusGroup, _ := ihnc.ComputeConsensusGroup(randomness, uint64(0), 0, 0) - leaderAppearances[string(consensusGroup[0].PubKey())]++ + leader, consensusGroup, _ := ihnc.ComputeConsensusGroup(randomness, uint64(0), 0, 0) + leaderAppearances[string(leader.PubKey())]++ for _, v := range consensusGroup { consensusAppearances[string(v.PubKey())]++ } @@ -315,10 +325,6 @@ func TestIndexHashedGroupSelectorWithRater_GetValidatorWithPublicKeyShouldReturn eligibleMap[0] = list eligibleMap[core.MetachainShardId] = list sufflerArgs := &NodesShufflerArgs{ - NodesShard: 1, - NodesMeta: 1, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -326,12 +332,18 @@ func TestIndexHashedGroupSelectorWithRater_GetValidatorWithPublicKeyShouldReturn nodeShuffler, err := NewHashValidatorsShuffler(sufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -371,10 +383,6 @@ func TestIndexHashedGroupSelectorWithRater_GetValidatorWithPublicKeyShouldReturn eligibleMap[core.MetachainShardId] = list shufflerArgs := &NodesShufflerArgs{ - NodesShard: 1, - NodesMeta: 1, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -382,12 +390,18 @@ func TestIndexHashedGroupSelectorWithRater_GetValidatorWithPublicKeyShouldReturn nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -437,10 +451,6 @@ func TestIndexHashedGroupSelectorWithRater_GetValidatorWithPublicKeyShouldWork(t waitingMap := make(map[uint32][]Validator) shufflerArgs := &NodesShufflerArgs{ - NodesShard: 3, - NodesMeta: 3, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -448,7 +458,7 @@ func 
TestIndexHashedGroupSelectorWithRater_GetValidatorWithPublicKeyShouldWork(t nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() eligibleMap[core.MetachainShardId] = listMeta @@ -456,8 +466,14 @@ func TestIndexHashedGroupSelectorWithRater_GetValidatorWithPublicKeyShouldWork(t eligibleMap[1] = listShard1 arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -524,17 +540,13 @@ func TestIndexHashedGroupSelectorWithRater_GetAllEligibleValidatorsPublicKeys(t waitingMap := make(map[uint32][]Validator) shufflerArgs := &NodesShufflerArgs{ - NodesShard: 3, - NodesMeta: 3, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, } nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() eligibleMap[core.MetachainShardId] = listMeta @@ -542,8 +554,14 @@ func TestIndexHashedGroupSelectorWithRater_GetAllEligibleValidatorsPublicKeys(t eligibleMap[shardOneId] = listShard1 arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -836,10 +854,6 @@ func BenchmarkIndexHashedWithRaterGroupSelector_ComputeValidatorsGroup21of400(b eligibleMap[0] = list eligibleMap[core.MetachainShardId] = listMeta shufflerArgs := &NodesShufflerArgs{ - NodesShard: 400, - NodesMeta: 1, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -847,12 +861,18 @@ func BenchmarkIndexHashedWithRaterGroupSelector_ComputeValidatorsGroup21of400(b nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(b, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusGroupSize, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -880,7 
+900,7 @@ func BenchmarkIndexHashedWithRaterGroupSelector_ComputeValidatorsGroup21of400(b for i := 0; i < b.N; i++ { randomness := strconv.Itoa(i) - list2, _ := ihncRater.ComputeConsensusGroup([]byte(randomness), 0, 0, 0) + _, list2, _ := ihncRater.ComputeConsensusGroup([]byte(randomness), 0, 0, 0) assert.Equal(b, consensusGroupSize, len(list2)) } diff --git a/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go b/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go index 32cc2ca8326..713aeac1643 100644 --- a/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go +++ b/sharding/nodesCoordinator/indexHashedNodesCoordinator_test.go @@ -24,15 +24,19 @@ import ( "github.com/stretchr/testify/require" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/sharding/mock" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage/cache" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + testscommonConsensus "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" "github.com/multiversx/mx-chain-go/testscommon/genericMocks" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/nodeTypeProviderMock" + "github.com/multiversx/mx-chain-go/testscommon/shardingMocks/nodesCoordinatorMocks" vic "github.com/multiversx/mx-chain-go/testscommon/validatorInfoCacher" ) @@ -98,35 +102,45 @@ func createArguments() ArgNodesCoordinator { eligibleMap := createDummyNodesMap(10, nbShards, "eligible") waitingMap := createDummyNodesMap(3, nbShards, "waiting") shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, } nodeShuffler, _ := NewHashValidatorsShuffler(shufflerArgs) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, - Marshalizer: &mock.MarshalizerMock{}, - Hasher: &hashingMocks.HasherMock{}, - Shuffler: nodeShuffler, - EpochStartNotifier: epochStartSubscriber, - BootStorer: bootStorer, - NbShards: nbShards, - EligibleNodes: eligibleMap, - WaitingNodes: waitingMap, - SelfPublicKey: []byte("test"), - ConsensusGroupCache: &mock.NodesCoordinatorCacheMock{}, - ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, - IsFullArchive: false, - ChanStopNode: make(chan endProcess.ArgEndProcess), - NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + 
return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + ShardMinNumNodes: 10, + MetachainConsensusGroupSize: 1, + MetachainMinNumNodes: 10, + }, nil + }, + }, + Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Shuffler: nodeShuffler, + EpochStartNotifier: epochStartSubscriber, + BootStorer: bootStorer, + NbShards: nbShards, + EligibleNodes: eligibleMap, + WaitingNodes: waitingMap, + SelfPublicKey: []byte("test"), + ConsensusGroupCache: &mock.NodesCoordinatorCacheMock{}, + ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, + IsFullArchive: false, + ChanStopNode: make(chan endProcess.ArgEndProcess), + NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{ IsRefactorPeersMiniBlocksFlagEnabledField: true, }, @@ -159,14 +173,14 @@ func TestNewIndexHashedNodesCoordinator_NilHasherShouldErr(t *testing.T) { require.Nil(t, ihnc) } -func TestNewIndexHashedNodesCoordinator_InvalidConsensusGroupSizeShouldErr(t *testing.T) { +func TestNewIndexHashedNodesCoordinator_NilChainParametersHandleShouldErr(t *testing.T) { t.Parallel() arguments := createArguments() - arguments.ShardConsensusGroupSize = 0 + arguments.ChainParametersHandler = nil ihnc, err := NewIndexHashedNodesCoordinator(arguments) - require.Equal(t, ErrInvalidConsensusGroupSize, err) + require.Equal(t, ErrNilChainParametersHandler, err) require.Nil(t, ihnc) } @@ -277,10 +291,6 @@ func TestIndexHashedNodesCoordinator_OkValShouldWork(t *testing.T) { waitingMap := createDummyNodesMap(3, 3, "waiting") shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -288,12 +298,18 @@ func TestIndexHashedNodesCoordinator_OkValShouldWork(t *testing.T) { nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 2, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 2, + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -322,27 +338,12 @@ func TestIndexHashedNodesCoordinator_OkValShouldWork(t *testing.T) { // ------- ComputeValidatorsGroup -func TestIndexHashedNodesCoordinator_NewCoordinatorGroup0SizeShouldErr(t *testing.T) { - t.Parallel() - - arguments := createArguments() - arguments.MetaConsensusGroupSize = 0 - ihnc, err := NewIndexHashedNodesCoordinator(arguments) - - require.Equal(t, ErrInvalidConsensusGroupSize, err) - require.Nil(t, ihnc) -} - func TestIndexHashedNodesCoordinator_NewCoordinatorTooFewNodesShouldErr(t *testing.T) { t.Parallel() eligibleMap := createDummyNodesMap(5, 3, "eligible") waitingMap := createDummyNodesMap(3, 3, "waiting") shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -350,12 +351,24 @@ 
func TestIndexHashedNodesCoordinator_NewCoordinatorTooFewNodesShouldErr(t *testi nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 10, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 10, + MetachainConsensusGroupSize: 1, + } + }, + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 10, + MetachainConsensusGroupSize: 1, + }, nil + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -385,10 +398,11 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroupNilRandomnessShouldEr arguments := createArguments() ihnc, _ := NewIndexHashedNodesCoordinator(arguments) - list2, err := ihnc.ComputeConsensusGroup(nil, 0, 0, 0) + leader, list2, err := ihnc.ComputeConsensusGroup(nil, 0, 0, 0) require.Equal(t, ErrNilRandomness, err) require.Nil(t, list2) + require.Nil(t, leader) } func TestIndexHashedNodesCoordinator_ComputeValidatorsGroupInvalidShardIdShouldErr(t *testing.T) { @@ -396,10 +410,11 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroupInvalidShardIdShouldE arguments := createArguments() ihnc, _ := NewIndexHashedNodesCoordinator(arguments) - list2, err := ihnc.ComputeConsensusGroup([]byte("radomness"), 0, 5, 0) + leader, list2, err := ihnc.ComputeConsensusGroup([]byte("radomness"), 0, 5, 0) require.Equal(t, ErrInvalidShardId, err) require.Nil(t, list2) + require.Nil(t, leader) } // ------- functionality tests @@ -415,10 +430,6 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup1ValidatorShouldRetur nodesMap[0] = list nodesMap[core.MetachainShardId] = tmp[core.MetachainShardId] shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -426,12 +437,24 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup1ValidatorShouldRetur nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + }, nil + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -451,10 +474,11 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup1ValidatorShouldRetur 
NodesCoordinatorRegistryFactory: createNodesCoordinatorRegistryFactory(), } ihnc, _ := NewIndexHashedNodesCoordinator(arguments) - list2, err := ihnc.ComputeConsensusGroup([]byte("randomness"), 0, 0, 0) + leader, list2, err := ihnc.ComputeConsensusGroup([]byte("randomness"), 0, 0, 0) - require.Equal(t, list, list2) require.Nil(t, err) + require.Equal(t, list, list2) + require.Equal(t, list[0], leader) } func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10locksNoMemoization(t *testing.T) { @@ -463,10 +487,6 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10locksNoM waitingMap := make(map[uint32][]Validator) eligibleMap := createDummyNodesMap(nodesPerShard, 1, "eligible") shufflerArgs := &NodesShufflerArgs{ - NodesShard: nodesPerShard, - NodesMeta: nodesPerShard, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -474,7 +494,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10locksNoM nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() getCounter := int32(0) @@ -492,8 +512,20 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10locksNoM } arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusGroupSize, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + } + }, + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + }, nil + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -519,12 +551,14 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10locksNoM miniBlocks := 10 var list2 []Validator + var leader Validator for i := 0; i < miniBlocks; i++ { for j := 0; j <= i; j++ { randomness := strconv.Itoa(j) - list2, err = ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) + leader, list2, err = ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) require.Nil(t, err) require.Equal(t, consensusGroupSize, len(list2)) + require.NotNil(t, leader) } } @@ -540,10 +574,6 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10BlocksMe waitingMap := make(map[uint32][]Validator) eligibleMap := createDummyNodesMap(nodesPerShard, 1, "eligible") shufflerArgs := &NodesShufflerArgs{ - NodesShard: nodesPerShard, - NodesMeta: nodesPerShard, - Hysteresis: 0, - Adaptivity: false, ShuffleBetweenShards: false, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -551,7 +581,7 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10BlocksMe nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() 
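The test hunks around this point all follow the same pattern: instead of fixed ShardConsensusGroupSize / MetaConsensusGroupSize fields, ArgNodesCoordinator now receives a ChainParametersHandler stub that resolves consensus sizes per epoch. As a standalone illustration (not part of the patch), the Go sketch below mirrors that lookup with simplified local stand-ins; chainParams, paramsHandlerStub, metachainShardID and consensusGroupSizeForShardAndEpoch are illustrative names for this example only, not the repo's actual types.

// Standalone sketch: mimics how the tests stub per-epoch chain parameters and
// how a coordinator could resolve the consensus group size from them.
package main

import "fmt"

// metachainShardID is a local stand-in for the metachain shard identifier.
const metachainShardID = ^uint32(0)

// chainParams is a trimmed-down stand-in for the per-epoch chain parameters.
type chainParams struct {
	ShardConsensusGroupSize     uint32
	MetachainConsensusGroupSize uint32
}

// paramsHandlerStub mirrors the test stubs: the caller injects the function
// that resolves parameters for a given epoch.
type paramsHandlerStub struct {
	chainParametersForEpoch func(epoch uint32) (chainParams, error)
}

func (s *paramsHandlerStub) ChainParametersForEpoch(epoch uint32) (chainParams, error) {
	return s.chainParametersForEpoch(epoch)
}

// consensusGroupSizeForShardAndEpoch fetches the parameters for the requested
// epoch, then picks the shard or metachain consensus size by shard ID.
func consensusGroupSizeForShardAndEpoch(h *paramsHandlerStub, shardID uint32, epoch uint32) (int, error) {
	params, err := h.ChainParametersForEpoch(epoch)
	if err != nil {
		return 0, err
	}
	if shardID == metachainShardID {
		return int(params.MetachainConsensusGroupSize), nil
	}
	return int(params.ShardConsensusGroupSize), nil
}

func main() {
	stub := &paramsHandlerStub{
		chainParametersForEpoch: func(epoch uint32) (chainParams, error) {
			// same shape as the stubs in the tests: constant sizes for every epoch
			return chainParams{ShardConsensusGroupSize: 63, MetachainConsensusGroupSize: 400}, nil
		},
	}

	shardSize, _ := consensusGroupSizeForShardAndEpoch(stub, 0, 37)
	metaSize, _ := consensusGroupSizeForShardAndEpoch(stub, metachainShardID, 37)
	fmt.Println(shardSize, metaSize) // 63 400
}

In the diff itself this is exercised through chainParameters.ChainParametersHandlerStub, whose CurrentChainParametersCalled and ChainParametersForEpochCalled fields play the role of the injected function above.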
getCounter := 0 @@ -582,8 +612,20 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10BlocksMe } arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusGroupSize, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + } + }, + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + }, nil + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -609,12 +651,14 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup400of400For10BlocksMe miniBlocks := 10 var list2 []Validator + var leader Validator for i := 0; i < miniBlocks; i++ { for j := 0; j <= i; j++ { randomness := strconv.Itoa(j) - list2, err = ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) + leader, list2, err = ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) require.Nil(t, err) require.Equal(t, consensusGroupSize, len(list2)) + require.NotNil(t, leader) } } @@ -641,10 +685,6 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup63of400TestEqualSameP eligibleMap := createDummyNodesMap(nodesPerShard, 1, "eligible") shufflerArgs := &NodesShufflerArgs{ - NodesShard: nodesPerShard, - NodesMeta: nodesPerShard, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -652,12 +692,18 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup63of400TestEqualSameP nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusGroupSize, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -682,13 +728,15 @@ func TestIndexHashedNodesCoordinator_ComputeValidatorsGroup63of400TestEqualSameP repeatPerSampling := 100 list := make([][]Validator, repeatPerSampling) + var leader Validator for i := 0; i < nbDifferentSamplings; i++ { randomness := arguments.Hasher.Compute(strconv.Itoa(i)) fmt.Printf("starting selection with randomness: %s\n", hex.EncodeToString(randomness)) for j := 0; j < repeatPerSampling; j++ { - list[j], err = ihnc.ComputeConsensusGroup(randomness, 0, 0, 0) + leader, list[j], err = ihnc.ComputeConsensusGroup(randomness, 0, 0, 0) require.Nil(t, err) require.Equal(t, consensusGroupSize, len(list[j])) + require.NotNil(t, leader) } for j := 1; j < repeatPerSampling; j++ { @@ -705,10 +753,6 @@ func BenchmarkIndexHashedGroupSelector_ComputeValidatorsGroup21of400(b *testing. 
waitingMap := make(map[uint32][]Validator) eligibleMap := createDummyNodesMap(nodesPerShard, 1, "eligible") shufflerArgs := &NodesShufflerArgs{ - NodesShard: nodesPerShard, - NodesMeta: nodesPerShard, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -716,12 +760,18 @@ func BenchmarkIndexHashedGroupSelector_ComputeValidatorsGroup21of400(b *testing. nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(b, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusGroupSize, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -745,9 +795,10 @@ func BenchmarkIndexHashedGroupSelector_ComputeValidatorsGroup21of400(b *testing. for i := 0; i < b.N; i++ { randomness := strconv.Itoa(i) - list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), 0, 0, 0) + leader, list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), 0, 0, 0) require.Equal(b, consensusGroupSize, len(list2)) + require.NotNil(b, leader) } } @@ -779,10 +830,6 @@ func BenchmarkIndexHashedNodesCoordinator_CopyMaps(b *testing.B) { func runBenchmark(consensusGroupCache Cacher, consensusGroupSize int, nodesMap map[uint32][]Validator, b *testing.B) { waitingMap := make(map[uint32][]Validator) shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -790,12 +837,18 @@ func runBenchmark(consensusGroupCache Cacher, consensusGroupSize int, nodesMap m nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(b, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusGroupSize, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, EpochStartNotifier: epochStartSubscriber, @@ -821,8 +874,9 @@ func runBenchmark(consensusGroupCache Cacher, consensusGroupSize int, nodesMap m missedBlocks := 1000 for j := 0; j < missedBlocks; j++ { randomness := strconv.Itoa(j) - list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) + leader, list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), uint64(j), 0, 0) require.Equal(b, consensusGroupSize, len(list2)) + require.NotNil(b, leader) } } } @@ -830,10 +884,6 @@ func runBenchmark(consensusGroupCache Cacher, consensusGroupSize int, nodesMap m func 
computeMemoryRequirements(consensusGroupCache Cacher, consensusGroupSize int, nodesMap map[uint32][]Validator, b *testing.B) { waitingMap := make(map[uint32][]Validator) shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -841,12 +891,18 @@ func computeMemoryRequirements(consensusGroupCache Cacher, consensusGroupSize in nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(b, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: consensusGroupSize, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(consensusGroupSize), + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, EpochStartNotifier: epochStartSubscriber, @@ -873,8 +929,9 @@ func computeMemoryRequirements(consensusGroupCache Cacher, consensusGroupSize in missedBlocks := 1000 for i := 0; i < missedBlocks; i++ { randomness := strconv.Itoa(i) - list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), uint64(i), 0, 0) + leader, list2, _ := ihnc.ComputeConsensusGroup([]byte(randomness), uint64(i), 0, 0) require.Equal(b, consensusGroupSize, len(list2)) + require.NotNil(b, leader) } m2 := runtime.MemStats{} @@ -971,10 +1028,6 @@ func TestIndexHashedNodesCoordinator_GetValidatorWithPublicKeyShouldWork(t *test eligibleMap[0] = listShard0 eligibleMap[1] = listShard1 shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -982,12 +1035,18 @@ func TestIndexHashedNodesCoordinator_GetValidatorWithPublicKeyShouldWork(t *test nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -1056,10 +1115,6 @@ func TestIndexHashedGroupSelector_GetAllEligibleValidatorsPublicKeys(t *testing. eligibleMap[shardZeroId] = listShard0 eligibleMap[shardOneId] = listShard1 shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -1067,12 +1122,18 @@ func TestIndexHashedGroupSelector_GetAllEligibleValidatorsPublicKeys(t *testing. 
nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -1133,10 +1194,6 @@ func TestIndexHashedGroupSelector_GetAllWaitingValidatorsPublicKeys(t *testing.T waitingMap[shardOneId] = listShard1 shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -1144,7 +1201,7 @@ func TestIndexHashedGroupSelector_GetAllWaitingValidatorsPublicKeys(t *testing.T nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() eligibleMap := make(map[uint32][]Validator) @@ -1152,8 +1209,14 @@ func TestIndexHashedGroupSelector_GetAllWaitingValidatorsPublicKeys(t *testing.T eligibleMap[shardZeroId] = []Validator{&validator{}} arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + }, Marshalizer: &mock.MarshalizerMock{}, Hasher: &hashingMocks.HasherMock{}, Shuffler: nodeShuffler, @@ -1529,10 +1592,6 @@ func TestIndexHashedNodesCoordinator_EpochStart_EligibleSortedAscendingByIndex(t eligibleMap[core.MetachainShardId] = list shufflerArgs := &NodesShufflerArgs{ - NodesShard: 2, - NodesMeta: 2, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -1540,25 +1599,39 @@ func TestIndexHashedNodesCoordinator_EpochStart_EligibleSortedAscendingByIndex(t nodeShuffler, err := NewHashValidatorsShuffler(shufflerArgs) require.Nil(t, err) - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, - Marshalizer: &mock.MarshalizerMock{}, - Hasher: &hashingMocks.HasherMock{}, - Shuffler: nodeShuffler, - EpochStartNotifier: epochStartSubscriber, - BootStorer: bootStorer, - NbShards: nbShards, - EligibleNodes: eligibleMap, - WaitingNodes: map[uint32][]Validator{}, - SelfPublicKey: []byte("test"), - ConsensusGroupCache: &mock.NodesCoordinatorCacheMock{}, - ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, - ChanStopNode: make(chan endProcess.ArgEndProcess), - NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, + ChainParametersHandler: 
&chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + ChainParametersForEpochCalled: func(_ uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + ShardMinNumNodes: 2, + MetachainMinNumNodes: 2, + }, nil + }, + }, + Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Shuffler: nodeShuffler, + EpochStartNotifier: epochStartSubscriber, + BootStorer: bootStorer, + NbShards: nbShards, + EligibleNodes: eligibleMap, + WaitingNodes: map[uint32][]Validator{}, + SelfPublicKey: []byte("test"), + ConsensusGroupCache: &mock.NodesCoordinatorCacheMock{}, + ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, + ChanStopNode: make(chan endProcess.ArgEndProcess), + NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{ IsRefactorPeersMiniBlocksFlagEnabledField: true, }, @@ -1597,10 +1670,12 @@ func TestIndexHashedNodesCoordinator_GetConsensusValidatorsPublicKeysNotExisting require.Nil(t, err) var pKeys []string + var leader string randomness := []byte("randomness") - pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 1) + leader, pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 1) require.True(t, errors.Is(err, ErrEpochNodesConfigDoesNotExist)) require.Nil(t, pKeys) + require.Empty(t, leader) } func TestIndexHashedNodesCoordinator_GetConsensusValidatorsPublicKeysExistingEpoch(t *testing.T) { @@ -1613,11 +1688,13 @@ func TestIndexHashedNodesCoordinator_GetConsensusValidatorsPublicKeysExistingEpo shard0PubKeys := validatorsPubKeys(args.EligibleNodes[0]) var pKeys []string + var leader string randomness := []byte("randomness") - pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) + leader, pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) require.Nil(t, err) require.True(t, len(pKeys) > 0) require.True(t, isStringSubgroup(pKeys, shard0PubKeys)) + require.NotEmpty(t, leader) } func TestIndexHashedNodesCoordinator_GetValidatorsIndexes(t *testing.T) { @@ -1629,13 +1706,15 @@ func TestIndexHashedNodesCoordinator_GetValidatorsIndexes(t *testing.T) { randomness := []byte("randomness") var pKeys []string - pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) + var leader string + leader, pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) require.Nil(t, err) var indexes []uint64 indexes, err = ihnc.GetValidatorsIndexes(pKeys, 0) require.Nil(t, err) require.Equal(t, len(pKeys), len(indexes)) + require.NotEmpty(t, leader) } func TestIndexHashedNodesCoordinator_GetValidatorsIndexesInvalidPubKey(t *testing.T) { @@ -1647,8 +1726,10 @@ func TestIndexHashedNodesCoordinator_GetValidatorsIndexesInvalidPubKey(t *testin randomness := []byte("randomness") var pKeys []string - pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) + var leader string + leader, pKeys, err = ihnc.GetConsensusValidatorsPublicKeys(randomness, 0, 0, 0) require.Nil(t, err) + require.NotEmpty(t, leader) var indexes []uint64 pKeys[0] = "dummy" @@ -1783,6 +1864,39 @@ func TestIndexHashedNodesCoordinator_GetConsensusWhitelistedNodesEpoch1(t *testi } } +func TestIndexHashedNodesCoordinator_GetAllEligibleValidatorsPublicKeysForShard(t 
*testing.T) { + t.Parallel() + + t.Run("missing nodes config should error", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + arguments.ValidatorInfoCacher = dataPool.NewCurrentEpochValidatorInfoPool() + ihnc, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + validators, err := ihnc.GetAllEligibleValidatorsPublicKeysForShard(100, 0) + require.True(t, errors.Is(err, ErrEpochNodesConfigDoesNotExist)) + require.Nil(t, validators) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + arguments.ValidatorInfoCacher = dataPool.NewCurrentEpochValidatorInfoPool() + ihnc, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + expectedValidators := make([]string, 0, len(arguments.EligibleNodes[0])) + for _, val := range arguments.EligibleNodes[0] { + expectedValidators = append(expectedValidators, string(val.PubKey())) + } + validators, err := ihnc.GetAllEligibleValidatorsPublicKeysForShard(0, 0) + require.NoError(t, err) + require.Equal(t, expectedValidators, validators) + }) +} + func TestIndexHashedNodesCoordinator_GetConsensusWhitelistedNodesAfterRevertToEpoch(t *testing.T) { t.Parallel() @@ -1850,15 +1964,43 @@ func TestIndexHashedNodesCoordinator_GetConsensusWhitelistedNodesAfterRevertToEp func TestIndexHashedNodesCoordinator_ConsensusGroupSize(t *testing.T) { t.Parallel() + testEpoch := uint32(37) + shardConsensusGroupSize, metaConsensusGroupSize := 1, 1 arguments := createArguments() + arguments.Epoch = testEpoch - 1 + numTimesChainParametersForEpochWasCalled := 0 + arguments.ChainParametersHandler = &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: uint32(shardConsensusGroupSize), + MetachainConsensusGroupSize: uint32(metaConsensusGroupSize), + } + }, + ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) { + if numTimesChainParametersForEpochWasCalled == 0 { + require.Equal(t, testEpoch-1, epoch) + } else { + require.Equal(t, testEpoch, epoch) + } + numTimesChainParametersForEpochWasCalled++ + + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + }, nil + }, + } ihnc, err := NewIndexHashedNodesCoordinator(arguments) require.Nil(t, err) - consensusSizeShard := ihnc.ConsensusGroupSize(0) - consensusSizeMeta := ihnc.ConsensusGroupSize(core.MetachainShardId) + consensusSizeShard := ihnc.ConsensusGroupSizeForShardAndEpoch(0, testEpoch) + consensusSizeMeta := ihnc.ConsensusGroupSizeForShardAndEpoch(core.MetachainShardId, testEpoch) + + require.Equal(t, shardConsensusGroupSize, consensusSizeShard) + require.Equal(t, metaConsensusGroupSize, consensusSizeMeta) - require.Equal(t, arguments.ShardConsensusGroupSize, consensusSizeShard) - require.Equal(t, arguments.MetaConsensusGroupSize, consensusSizeMeta) + // consensus group size from chain parameters should have been called once from the constructor, once for shard and once for meta + require.Equal(t, 3, numTimesChainParametersForEpochWasCalled) } func TestIndexHashedNodesCoordinator_GetNumTotalEligible(t *testing.T) { @@ -2516,14 +2658,10 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) t.Run("missing nodes config for current epoch should error ", func(t *testing.T) { t.Parallel() - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber 
:= &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -2532,15 +2670,21 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) require.Nil(t, err) arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, - Marshalizer: &mock.MarshalizerMock{}, - Hasher: &hashingMocks.HasherMock{}, - Shuffler: nodeShuffler, - EpochStartNotifier: epochStartSubscriber, - BootStorer: bootStorer, - ShardIDAsObserver: 0, - NbShards: 2, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + }, + Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Shuffler: nodeShuffler, + EpochStartNotifier: epochStartSubscriber, + BootStorer: bootStorer, + ShardIDAsObserver: 0, + NbShards: 2, EligibleNodes: map[uint32][]Validator{ core.MetachainShardId: {newValidatorMock([]byte("pk"), 1, 0)}, }, @@ -2584,7 +2728,7 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) waitingMap[core.MetachainShardId] = listMeta waitingMap[shardZeroId] = listShard0 - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() eligibleMap := make(map[uint32][]Validator) @@ -2592,10 +2736,6 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) eligibleMap[shardZeroId] = []Validator{&validator{}} shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -2604,24 +2744,30 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) require.Nil(t, err) arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, - Marshalizer: &mock.MarshalizerMock{}, - Hasher: &hashingMocks.HasherMock{}, - Shuffler: nodeShuffler, - EpochStartNotifier: epochStartSubscriber, - BootStorer: bootStorer, - ShardIDAsObserver: shardZeroId, - NbShards: 2, - EligibleNodes: eligibleMap, - WaitingNodes: waitingMap, - SelfPublicKey: []byte("key"), - ConsensusGroupCache: &mock.NodesCoordinatorCacheMock{}, - ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, - ChanStopNode: make(chan endProcess.ArgEndProcess), - NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, - EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, - ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + }, + Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Shuffler: nodeShuffler, + EpochStartNotifier: epochStartSubscriber, + BootStorer: bootStorer, + ShardIDAsObserver: shardZeroId, + NbShards: 2, + EligibleNodes: 
eligibleMap, + WaitingNodes: waitingMap, + SelfPublicKey: []byte("key"), + ConsensusGroupCache: &mock.NodesCoordinatorCacheMock{}, + ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, + ChanStopNode: make(chan endProcess.ArgEndProcess), + NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, + EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, + ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, GenesisNodesSetupHandler: &mock.NodesSetupMock{ MinShardHysteresisNodesCalled: func() uint32 { return 0 @@ -2669,7 +2815,7 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) waitingMap[core.MetachainShardId] = listMeta waitingMap[shardZeroId] = listShard0 - epochStartSubscriber := &mock.EpochStartNotifierStub{} + epochStartSubscriber := &testscommonConsensus.EpochStartNotifierStub{} bootStorer := genericMocks.NewStorerMock() eligibleMap := make(map[uint32][]Validator) @@ -2677,10 +2823,6 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) eligibleMap[shardZeroId] = []Validator{&validator{}} shufflerArgs := &NodesShufflerArgs{ - NodesShard: 10, - NodesMeta: 10, - Hysteresis: hysteresis, - Adaptivity: adaptivity, ShuffleBetweenShards: shuffleBetweenShards, MaxNodesEnableConfig: nil, EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, @@ -2689,24 +2831,30 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) require.Nil(t, err) arguments := ArgNodesCoordinator{ - ShardConsensusGroupSize: 1, - MetaConsensusGroupSize: 1, - Marshalizer: &mock.MarshalizerMock{}, - Hasher: &hashingMocks.HasherMock{}, - Shuffler: nodeShuffler, - EpochStartNotifier: epochStartSubscriber, - BootStorer: bootStorer, - ShardIDAsObserver: shardZeroId, - NbShards: 2, - EligibleNodes: eligibleMap, - WaitingNodes: waitingMap, - SelfPublicKey: []byte("key"), - ConsensusGroupCache: &mock.NodesCoordinatorCacheMock{}, - ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, - ChanStopNode: make(chan endProcess.ArgEndProcess), - NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, - EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, - ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, + ChainParametersHandler: &chainParameters.ChainParametersHandlerStub{ + CurrentChainParametersCalled: func() config.ChainParametersByEpochConfig { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: 1, + MetachainConsensusGroupSize: 1, + } + }, + }, + Marshalizer: &mock.MarshalizerMock{}, + Hasher: &hashingMocks.HasherMock{}, + Shuffler: nodeShuffler, + EpochStartNotifier: epochStartSubscriber, + BootStorer: bootStorer, + ShardIDAsObserver: shardZeroId, + NbShards: 2, + EligibleNodes: eligibleMap, + WaitingNodes: waitingMap, + SelfPublicKey: []byte("key"), + ConsensusGroupCache: &mock.NodesCoordinatorCacheMock{}, + ShuffledOutHandler: &mock.ShuffledOutHandlerStub{}, + ChanStopNode: make(chan endProcess.ArgEndProcess), + NodeTypeProvider: &nodeTypeProviderMock.NodeTypeProviderStub{}, + EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, + ValidatorInfoCacher: &vic.ValidatorInfoCacherStub{}, GenesisNodesSetupHandler: &mock.NodesSetupMock{ MinShardHysteresisNodesCalled: func() uint32 { return 2 @@ -2761,3 +2909,324 @@ func TestIndexHashedGroupSelector_GetWaitingEpochsLeftForPublicKey(t *testing.T) require.Equal(t, uint32(3), epochsLeft) }) } + +func TestNodesCoordinator_CustomConsensusGroupSize(t *testing.T) { + arguments := createArguments() + numShards := uint32(2) + nodesPerShard := uint32(3) + eligibleMap := 
createDummyNodesMap(nodesPerShard, numShards, "eligible") + waitingMap := createDummyNodesMap(0, numShards, "waiting") + arguments.EligibleNodes = eligibleMap + arguments.WaitingNodes = waitingMap + arguments.ValidatorInfoCacher = dataPool.NewCurrentEpochValidatorInfoPool() + + consensusParams := []struct { + enableEpoch uint32 + shardCnsSize uint32 + metaCnsSize uint32 + shardMinNodes uint32 + metaMinNodes uint32 + }{ + { + enableEpoch: 9, + shardCnsSize: 3, + shardMinNodes: 3, + metaCnsSize: 3, + metaMinNodes: 3, + }, + { + enableEpoch: 6, + shardCnsSize: 3, + shardMinNodes: 3, + metaCnsSize: 2, + metaMinNodes: 2, + }, + { + enableEpoch: 3, + shardCnsSize: 3, + shardMinNodes: 3, + metaCnsSize: 3, + metaMinNodes: 3, + }, + { + enableEpoch: 0, + shardCnsSize: 2, + shardMinNodes: 2, + metaCnsSize: 3, + metaMinNodes: 3, + }, + } + arguments.ChainParametersHandler = &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) { + for _, cfg := range consensusParams { + if epoch >= cfg.enableEpoch { + return config.ChainParametersByEpochConfig{ + ShardConsensusGroupSize: cfg.shardCnsSize, + ShardMinNumNodes: cfg.shardMinNodes, + MetachainConsensusGroupSize: cfg.metaCnsSize, + MetachainMinNumNodes: cfg.metaMinNodes, + }, nil + } + } + + return config.ChainParametersByEpochConfig{}, errors.New("wrong test setup") + }, + } + + shufflerArgs := &NodesShufflerArgs{ + ShuffleBetweenShards: shuffleBetweenShards, + EnableEpochsHandler: &mock.EnableEpochsHandlerMock{}, + MaxNodesEnableConfig: []config.MaxNodesChangeConfig{ + {EpochEnable: 0, MaxNumNodes: nodesPerShard * (numShards + 1), NodesToShufflePerShard: 2}, + {EpochEnable: 3, MaxNumNodes: nodesPerShard * (numShards + 1), NodesToShufflePerShard: 3}, + }, + } + arguments.Shuffler, _ = NewHashValidatorsShuffler(shufflerArgs) + + ihnc, _ := NewIndexHashedNodesCoordinator(arguments) + require.NotNil(t, ihnc) + + numEpochsToCheck := uint32(100) + checksCounter := 0 + for ep := uint32(0); ep < numEpochsToCheck; ep++ { + for _, cfg := range consensusParams { + if ep >= cfg.enableEpoch { + changeEpochAndTestNewConsensusSizes(&consensusSizeChangeTestArgs{ + t: t, + ihnc: ihnc, + epoch: ep, + expectedShardMinNodes: cfg.shardMinNodes, + expectedMetaMinNodes: cfg.metaMinNodes, + }) + checksCounter++ + break + } + } + } + require.Equal(t, numEpochsToCheck, uint32(checksCounter)) +} + +func TestIndexHashedNodesCoordinator_cacheConsensusGroup(t *testing.T) { + t.Parallel() + + maxNumValuesCache := 3 + key := []byte("key") + + leader := &validator{ + pubKey: []byte("leader"), + chances: 10, + index: 20, + } + validator1 := &validator{ + pubKey: []byte("validator1"), + chances: 10, + index: 20, + } + + t.Run("adding a key should work", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + + arguments.ConsensusGroupCache, _ = cache.NewLRUCache(maxNumValuesCache) + nodesCoordinator, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + consensusGroup := []Validator{leader, validator1} + expectedData := &savedConsensusGroup{ + leader: leader, + consensusGroup: consensusGroup, + } + + nodesCoordinator.cacheConsensusGroup(key, consensusGroup, leader) + value := nodesCoordinator.searchConsensusForKey(key) + + require.NotNil(t, value) + require.Equal(t, expectedData, value) + }) + + t.Run("adding a key twice should overwrite the value", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + + arguments.ConsensusGroupCache, _ = 
cache.NewLRUCache(maxNumValuesCache) + nodesCoordinator, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + cg1 := []Validator{leader, validator1} + cg2 := []Validator{leader} + expectedData := &savedConsensusGroup{ + leader: leader, + consensusGroup: cg2, + } + + nodesCoordinator.cacheConsensusGroup(key, cg1, leader) + nodesCoordinator.cacheConsensusGroup(key, cg2, leader) + value := nodesCoordinator.searchConsensusForKey(key) + require.NotNil(t, value) + require.Equal(t, expectedData, value) + }) + + t.Run("adding more keys than the cache size should remove the oldest key", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + + key1 := []byte("key1") + key2 := []byte("key2") + key3 := []byte("key3") + key4 := []byte("key4") + + cg1 := []Validator{leader, validator1} + cg2 := []Validator{leader} + cg3 := []Validator{validator1} + cg4 := []Validator{leader, validator1, validator1} + + arguments.ConsensusGroupCache, _ = cache.NewLRUCache(maxNumValuesCache) + nodesCoordinator, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + nodesCoordinator.cacheConsensusGroup(key1, cg1, leader) + nodesCoordinator.cacheConsensusGroup(key2, cg2, leader) + nodesCoordinator.cacheConsensusGroup(key3, cg3, leader) + nodesCoordinator.cacheConsensusGroup(key4, cg4, leader) + + value := nodesCoordinator.searchConsensusForKey(key1) + require.Nil(t, value) + + value = nodesCoordinator.searchConsensusForKey(key2) + require.Equal(t, cg2, value.consensusGroup) + + value = nodesCoordinator.searchConsensusForKey(key3) + require.Equal(t, cg3, value.consensusGroup) + + value = nodesCoordinator.searchConsensusForKey(key4) + require.Equal(t, cg4, value.consensusGroup) + }) +} + +func TestIndexHashedNodesCoordinator_selectLeaderAndConsensusGroup(t *testing.T) { + t.Parallel() + + validator1 := &validator{pubKey: []byte("validator1")} + validator2 := &validator{pubKey: []byte("validator2")} + validator3 := &validator{pubKey: []byte("validator3")} + validator4 := &validator{pubKey: []byte("validator4")} + + randomness := []byte("randomness") + epoch := uint32(1) + + eligibleList := []Validator{validator1, validator2, validator3, validator4} + consensusSize := len(eligibleList) + expectedError := errors.New("expected error") + selectFunc := func(randSeed []byte, sampleSize uint32) ([]uint32, error) { + if len(eligibleList) < int(sampleSize) { + return nil, expectedError + } + + result := make([]uint32, sampleSize) + for i := 0; i < int(sampleSize); i++ { + // reverse order from eligible list + result[i] = uint32(len(eligibleList) - 1 - i) + } + + return result, nil + } + expectedConsensusFixedOrder := []Validator{validator1, validator2, validator3, validator4} + expectedConsensusNotFixedOrder := []Validator{validator4, validator3, validator2, validator1} + expectedLeader := validator4 + + t.Run("with fixed ordering enabled, data not cached", func(t *testing.T) { + t.Parallel() + + arguments := createArguments() + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return true + }, + } + + ihnc, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + selector := &nodesCoordinatorMocks.RandomSelectorMock{ + SelectCalled: selectFunc, + } + + leader, cg, err := ihnc.selectLeaderAndConsensusGroup(selector, randomness, eligibleList, consensusSize, epoch) + require.Nil(t, err) + require.Equal(t, validator4, leader) + require.Equal(t, 
expectedLeader, leader) + require.Equal(t, expectedConsensusFixedOrder, cg) + }) + t.Run("with fixed ordering disabled, data not cached", func(t *testing.T) { + t.Parallel() + arguments := createArguments() + arguments.EnableEpochsHandler = &enableEpochsHandlerMock.EnableEpochsHandlerStub{ + IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool { + return false + }, + } + + ihnc, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + selector := &nodesCoordinatorMocks.RandomSelectorMock{ + SelectCalled: selectFunc, + } + + leader, cg, err := ihnc.selectLeaderAndConsensusGroup(selector, randomness, eligibleList, consensusSize, epoch) + require.Nil(t, err) + require.Equal(t, expectedLeader, leader) + require.Equal(t, expectedConsensusNotFixedOrder, cg) + }) +} + +func TestIndexHashedNodesCoordinator_GetCachedEpochs(t *testing.T) { + t.Parallel() + + arguments := createArguments() + ihnc, err := NewIndexHashedNodesCoordinator(arguments) + require.Nil(t, err) + + cachedEpochs := ihnc.GetCachedEpochs() + require.Equal(t, 1, len(cachedEpochs)) + + // add new epoch + ihnc.AddDummyEpoch(1) + cachedEpochs = ihnc.GetCachedEpochs() + require.Equal(t, 2, len(cachedEpochs)) + + // add new epoch + ihnc.AddDummyEpoch(2) + cachedEpochs = ihnc.GetCachedEpochs() + require.Equal(t, 3, len(cachedEpochs)) +} + +type consensusSizeChangeTestArgs struct { + t *testing.T + ihnc *indexHashedNodesCoordinator + epoch uint32 + expectedShardMinNodes uint32 + expectedMetaMinNodes uint32 +} + +func changeEpochAndTestNewConsensusSizes(args *consensusSizeChangeTestArgs) { + header := &block.MetaBlock{ + PrevRandSeed: []byte("rand seed"), + EpochStart: block.EpochStart{LastFinalizedHeaders: []block.EpochStartShardData{{}}}, + } + + header.Epoch = args.epoch + epochForPrevConfig := uint32(0) + if args.epoch > 0 { + epochForPrevConfig = args.epoch - 1 + } + args.ihnc.nodesConfig[args.epoch] = args.ihnc.nodesConfig[epochForPrevConfig] + body := createBlockBodyFromNodesCoordinator(args.ihnc, args.epoch, args.ihnc.validatorInfoCacher) + args.ihnc.EpochStartPrepare(header, body) + args.ihnc.EpochStartAction(header) + require.Len(args.t, args.ihnc.nodesConfig[args.epoch].eligibleMap[0], int(args.expectedShardMinNodes)) + require.Len(args.t, args.ihnc.nodesConfig[args.epoch].eligibleMap[common.MetachainShardId], int(args.expectedMetaMinNodes)) +} diff --git a/sharding/nodesCoordinator/interface.go b/sharding/nodesCoordinator/interface.go index 5e2d5564a5c..2d37beed268 100644 --- a/sharding/nodesCoordinator/interface.go +++ b/sharding/nodesCoordinator/interface.go @@ -5,6 +5,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" vmcommon "github.com/multiversx/mx-chain-vm-common-go" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/epochStart" "github.com/multiversx/mx-chain-go/state" ) @@ -22,16 +23,17 @@ type Validator interface { type NodesCoordinator interface { NodesCoordinatorHelper PublicKeysSelector - ComputeConsensusGroup(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []Validator, err error) + ComputeConsensusGroup(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader Validator, validatorsGroup []Validator, err error) GetValidatorWithPublicKey(publicKey []byte) (validator Validator, shardId uint32, err error) LoadState(key []byte) error GetSavedStateKey() []byte ShardIdForEpoch(epoch uint32) (uint32, error) 
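For readers following the interface change shown above, ComputeConsensusGroup and GetConsensusValidatorsPublicKeys now return the leader alongside the selected group, so every call site destructures three values. The standalone Go sketch below (not part of the patch) shows that call shape with simplified local types; validator and computeConsensusGroup are illustrative stand-ins, and treating the leader as the first selected validator is an assumption made only for this example.

// Standalone sketch: a toy selector that returns the leader as a separate
// first return value, matching the reworked call shape in the interface.
package main

import (
	"errors"
	"fmt"
)

type validator struct{ pubKey []byte }

func (v *validator) PubKey() []byte { return v.pubKey }

// computeConsensusGroup picks the first `size` eligible validators and reports
// the first of them as leader, returning (leader, group, error).
func computeConsensusGroup(randomness []byte, eligible []*validator, size int) (*validator, []*validator, error) {
	if len(randomness) == 0 {
		return nil, nil, errors.New("nil randomness")
	}
	if size < 1 || size > len(eligible) {
		return nil, nil, errors.New("invalid consensus group size")
	}
	group := eligible[:size]
	return group[0], group, nil
}

func main() {
	eligible := []*validator{
		{pubKey: []byte("pk0")}, {pubKey: []byte("pk1")}, {pubKey: []byte("pk2")},
	}

	// Call sites now destructure three values instead of two.
	leader, group, err := computeConsensusGroup([]byte("randomness"), eligible, 2)
	if err != nil {
		panic(err)
	}
	fmt.Printf("leader=%s group size=%d\n", leader.PubKey(), len(group))
}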
ShuffleOutForEpoch(_ uint32) GetConsensusWhitelistedNodes(epoch uint32) (map[string]struct{}, error) - ConsensusGroupSize(uint32) int + ConsensusGroupSizeForShardAndEpoch(uint32, uint32) int GetNumTotalEligible() uint64 GetWaitingEpochsLeftForPublicKey(publicKey []byte) (uint32, error) + GetCachedEpochs() map[uint32]struct{} IsInterfaceNil() bool } @@ -46,17 +48,17 @@ type EpochStartEventNotifier interface { type PublicKeysSelector interface { GetValidatorsIndexes(publicKeys []string, epoch uint32) ([]uint64, error) GetAllEligibleValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) + GetAllEligibleValidatorsPublicKeysForShard(epoch uint32, shardID uint32) ([]string, error) GetAllWaitingValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) GetAllLeavingValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) GetAllShuffledOutValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) GetShuffledOutToAuctionValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) - GetConsensusValidatorsPublicKeys(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) + GetConsensusValidatorsPublicKeys(randomness []byte, round uint64, shardId uint32, epoch uint32) (string, []string, error) GetOwnPublicKey() []byte } // NodesShuffler provides shuffling functionality for nodes type NodesShuffler interface { - UpdateParams(numNodesShard uint32, numNodesMeta uint32, hysteresis float32, adaptivity bool) UpdateNodeLists(args ArgsUpdateNodes) (*ResUpdateNodes, error) IsInterfaceNil() bool } @@ -178,3 +180,11 @@ type EpochNotifier interface { CheckEpoch(header data.HeaderHandler) IsInterfaceNil() bool } + +// ChainParametersHandler defines the actions that need to be done by a component that can handle chain parameters +type ChainParametersHandler interface { + CurrentChainParameters() config.ChainParametersByEpochConfig + AllChainParameters() []config.ChainParametersByEpochConfig + ChainParametersForEpoch(epoch uint32) (config.ChainParametersByEpochConfig, error) + IsInterfaceNil() bool +} diff --git a/sharding/nodesCoordinator/shardingArgs.go b/sharding/nodesCoordinator/shardingArgs.go index 67c542952d7..02788e0e0a8 100644 --- a/sharding/nodesCoordinator/shardingArgs.go +++ b/sharding/nodesCoordinator/shardingArgs.go @@ -11,8 +11,7 @@ import ( // ArgNodesCoordinator holds all dependencies required by the nodes coordinator in order to create new instances type ArgNodesCoordinator struct { - ShardConsensusGroupSize int - MetaConsensusGroupSize int + ChainParametersHandler ChainParametersHandler Marshalizer marshal.Marshalizer Hasher hashing.Hasher Shuffler NodesShuffler diff --git a/sharding/nodesSetup.go b/sharding/nodesSetup.go index 7a1a94691c6..32f9b1dbc92 100644 --- a/sharding/nodesSetup.go +++ b/sharding/nodesSetup.go @@ -6,6 +6,8 @@ import ( "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/core/check" + + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" ) @@ -58,20 +60,11 @@ func (ni *nodeInfo) IsInterfaceNil() bool { // NodesSetup hold data for decoded data from json file type NodesSetup struct { - StartTime int64 `json:"startTime"` - RoundDuration uint64 `json:"roundDuration"` - ConsensusGroupSize uint32 `json:"consensusGroupSize"` - MinNodesPerShard uint32 `json:"minNodesPerShard"` - - MetaChainConsensusGroupSize uint32 `json:"metaChainConsensusGroupSize"` - MetaChainMinNodes uint32 
`json:"metaChainMinNodes"` - Hysteresis float32 `json:"hysteresis"` - Adaptivity bool `json:"adaptivity"` - - InitialNodes []*InitialNode `json:"initialNodes"` + NodesSetupDTO + genesisChainParameters config.ChainParametersByEpochConfig genesisMaxNumShards uint32 - nrOfShards uint32 + numberOfShards uint32 nrOfNodes uint32 nrOfMetaChainNodes uint32 eligible map[uint32][]nodesCoordinator.GenesisNodeInfoHandler @@ -82,31 +75,54 @@ type NodesSetup struct { // NewNodesSetup creates a new decoded nodes structure from json config file func NewNodesSetup( - nodesFilePath string, + nodesConfig config.NodesConfig, + chainParametersProvider ChainParametersHandler, addressPubkeyConverter core.PubkeyConverter, validatorPubkeyConverter core.PubkeyConverter, genesisMaxNumShards uint32, ) (*NodesSetup, error) { - if check.IfNil(addressPubkeyConverter) { return nil, fmt.Errorf("%w for addressPubkeyConverter", ErrNilPubkeyConverter) } if check.IfNil(validatorPubkeyConverter) { return nil, fmt.Errorf("%w for validatorPubkeyConverter", ErrNilPubkeyConverter) } + if check.IfNil(chainParametersProvider) { + return nil, ErrNilChainParametersProvider + } if genesisMaxNumShards < 1 { return nil, fmt.Errorf("%w for genesisMaxNumShards", ErrInvalidMaximumNumberOfShards) } + genesisParams, err := chainParametersProvider.ChainParametersForEpoch(0) + if err != nil { + return nil, fmt.Errorf("NewNodesSetup: %w while fetching parameters for epoch 0", err) + } + nodes := &NodesSetup{ addressPubkeyConverter: addressPubkeyConverter, validatorPubkeyConverter: validatorPubkeyConverter, genesisMaxNumShards: genesisMaxNumShards, + genesisChainParameters: genesisParams, } - err := core.LoadJsonFile(nodes, nodesFilePath) - if err != nil { - return nil, err + initialNodes := make([]*InitialNode, 0, len(nodesConfig.InitialNodes)) + for _, item := range nodesConfig.InitialNodes { + initialNodes = append(initialNodes, &InitialNode{ + PubKey: item.PubKey, + Address: item.Address, + InitialRating: item.InitialRating, + nodeInfo: nodeInfo{}, + }) + } + + genesisChainParameters := nodes.genesisChainParameters + nodes.NodesSetupDTO = NodesSetupDTO{ + StartTime: nodesConfig.StartTime, + RoundDuration: genesisChainParameters.RoundDuration, + Hysteresis: genesisChainParameters.Hysteresis, + Adaptivity: genesisChainParameters.Adaptivity, + InitialNodes: initialNodes, } err = nodes.processConfig() @@ -160,34 +176,31 @@ func (ns *NodesSetup) processConfig() error { ns.nrOfNodes++ } - if ns.ConsensusGroupSize < 1 { + if ns.genesisChainParameters.ShardConsensusGroupSize < 1 { return ErrNegativeOrZeroConsensusGroupSize } - if ns.MinNodesPerShard < ns.ConsensusGroupSize { + if ns.genesisChainParameters.ShardMinNumNodes < ns.genesisChainParameters.ShardConsensusGroupSize { return ErrMinNodesPerShardSmallerThanConsensusSize } - if ns.nrOfNodes < ns.MinNodesPerShard { + if ns.nrOfNodes < ns.genesisChainParameters.ShardMinNumNodes { return ErrNodesSizeSmallerThanMinNoOfNodes } - - if ns.MetaChainConsensusGroupSize < 1 { + if ns.genesisChainParameters.MetachainMinNumNodes < 1 { return ErrNegativeOrZeroConsensusGroupSize } - if ns.MetaChainMinNodes < ns.MetaChainConsensusGroupSize { + if ns.genesisChainParameters.MetachainMinNumNodes < ns.genesisChainParameters.MetachainConsensusGroupSize { return ErrMinNodesPerShardSmallerThanConsensusSize } - - totalMinNodes := ns.MetaChainMinNodes + ns.MinNodesPerShard + totalMinNodes := ns.genesisChainParameters.MetachainMinNumNodes + ns.genesisChainParameters.ShardMinNumNodes if ns.nrOfNodes < totalMinNodes { 
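		// totalMinNodes sums the metachain minimum and the per-shard minimum taken from the
		// genesis chain parameters. Worked example from the tests below: with ShardMinNumNodes = 3
		// and MetachainMinNumNodes = 3, totalMinNodes = 6, so supplying only 5 initial nodes lands
		// in this branch and ErrNodesSizeSmallerThanMinNoOfNodes is returned.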
return ErrNodesSizeSmallerThanMinNoOfNodes } - return nil } func (ns *NodesSetup) processMetaChainAssigment() { ns.nrOfMetaChainNodes = 0 - for id := uint32(0); id < ns.MetaChainMinNodes; id++ { + for id := uint32(0); id < ns.genesisChainParameters.MetachainMinNumNodes; id++ { if ns.InitialNodes[id].pubKey != nil { ns.InitialNodes[id].assignedShard = core.MetachainShardId ns.InitialNodes[id].eligible = true @@ -195,13 +208,13 @@ func (ns *NodesSetup) processMetaChainAssigment() { } } - hystMeta := uint32(float32(ns.MetaChainMinNodes) * ns.Hysteresis) - hystShard := uint32(float32(ns.MinNodesPerShard) * ns.Hysteresis) + hystMeta := uint32(float32(ns.genesisChainParameters.MetachainMinNumNodes) * ns.Hysteresis) + hystShard := uint32(float32(ns.genesisChainParameters.ShardMinNumNodes) * ns.Hysteresis) - ns.nrOfShards = (ns.nrOfNodes - ns.nrOfMetaChainNodes - hystMeta) / (ns.MinNodesPerShard + hystShard) + ns.numberOfShards = (ns.nrOfNodes - ns.nrOfMetaChainNodes - hystMeta) / (ns.genesisChainParameters.ShardMinNumNodes + hystShard) - if ns.nrOfShards > ns.genesisMaxNumShards { - ns.nrOfShards = ns.genesisMaxNumShards + if ns.numberOfShards > ns.genesisMaxNumShards { + ns.numberOfShards = ns.genesisMaxNumShards } } @@ -209,8 +222,8 @@ func (ns *NodesSetup) processShardAssignment() { // initial implementation - as there is no other info than public key, we allocate first nodes in FIFO order to shards currentShard := uint32(0) countSetNodes := ns.nrOfMetaChainNodes - for ; currentShard < ns.nrOfShards; currentShard++ { - for id := countSetNodes; id < ns.nrOfMetaChainNodes+(currentShard+1)*ns.MinNodesPerShard; id++ { + for ; currentShard < ns.numberOfShards; currentShard++ { + for id := countSetNodes; id < ns.nrOfMetaChainNodes+(currentShard+1)*ns.genesisChainParameters.ShardMinNumNodes; id++ { // consider only nodes with valid public key if ns.InitialNodes[id].pubKey != nil { ns.InitialNodes[id].assignedShard = currentShard @@ -223,8 +236,8 @@ func (ns *NodesSetup) processShardAssignment() { // allocate the rest to waiting lists currentShard = 0 for i := countSetNodes; i < ns.nrOfNodes; i++ { - currentShard = (currentShard + 1) % (ns.nrOfShards + 1) - if currentShard == ns.nrOfShards { + currentShard = (currentShard + 1) % (ns.numberOfShards + 1) + if currentShard == ns.numberOfShards { currentShard = core.MetachainShardId } @@ -236,7 +249,7 @@ func (ns *NodesSetup) processShardAssignment() { } func (ns *NodesSetup) createInitialNodesInfo() { - nrOfShardAndMeta := ns.nrOfShards + 1 + nrOfShardAndMeta := ns.numberOfShards + 1 ns.eligible = make(map[uint32][]nodesCoordinator.GenesisNodeInfoHandler, nrOfShardAndMeta) ns.waiting = make(map[uint32][]nodesCoordinator.GenesisNodeInfoHandler, nrOfShardAndMeta) @@ -320,22 +333,22 @@ func (ns *NodesSetup) InitialNodesInfoForShard(shardId uint32) ([]nodesCoordinat // NumberOfShards returns the calculated number of shards func (ns *NodesSetup) NumberOfShards() uint32 { - return ns.nrOfShards + return ns.numberOfShards } // MinNumberOfNodes returns the minimum number of nodes func (ns *NodesSetup) MinNumberOfNodes() uint32 { - return ns.nrOfShards*ns.MinNodesPerShard + ns.MetaChainMinNodes + return ns.numberOfShards*ns.genesisChainParameters.ShardMinNumNodes + ns.genesisChainParameters.MetachainMinNumNodes } // MinShardHysteresisNodes returns the minimum number of hysteresis nodes per shard func (ns *NodesSetup) MinShardHysteresisNodes() uint32 { - return uint32(float32(ns.MinNodesPerShard) * ns.Hysteresis) + return 
uint32(float32(ns.genesisChainParameters.ShardMinNumNodes) * ns.Hysteresis) } // MinMetaHysteresisNodes returns the minimum number of hysteresis nodes in metachain func (ns *NodesSetup) MinMetaHysteresisNodes() uint32 { - return uint32(float32(ns.MetaChainMinNodes) * ns.Hysteresis) + return uint32(float32(ns.genesisChainParameters.MetachainMinNumNodes) * ns.Hysteresis) } // MinNumberOfNodesWithHysteresis returns the minimum number of nodes with hysteresis @@ -344,17 +357,17 @@ func (ns *NodesSetup) MinNumberOfNodesWithHysteresis() uint32 { hystNodesShard := ns.MinShardHysteresisNodes() minNumberOfNodes := ns.MinNumberOfNodes() - return minNumberOfNodes + hystNodesMeta + ns.nrOfShards*hystNodesShard + return minNumberOfNodes + hystNodesMeta + ns.numberOfShards*hystNodesShard } // MinNumberOfShardNodes returns the minimum number of nodes per shard func (ns *NodesSetup) MinNumberOfShardNodes() uint32 { - return ns.MinNodesPerShard + return ns.genesisChainParameters.ShardMinNumNodes } // MinNumberOfMetaNodes returns the minimum number of nodes in metachain func (ns *NodesSetup) MinNumberOfMetaNodes() uint32 { - return ns.MetaChainMinNodes + return ns.genesisChainParameters.MetachainMinNumNodes } // GetHysteresis returns the hysteresis value @@ -389,12 +402,30 @@ func (ns *NodesSetup) GetRoundDuration() uint64 { // GetShardConsensusGroupSize returns the shard consensus group size func (ns *NodesSetup) GetShardConsensusGroupSize() uint32 { - return ns.ConsensusGroupSize + return ns.genesisChainParameters.ShardConsensusGroupSize } // GetMetaConsensusGroupSize returns the metachain consensus group size func (ns *NodesSetup) GetMetaConsensusGroupSize() uint32 { - return ns.MetaChainConsensusGroupSize + return ns.genesisChainParameters.MetachainConsensusGroupSize +} + +// ExportNodesConfig will create and return the nodes' configuration +func (ns *NodesSetup) ExportNodesConfig() config.NodesConfig { + initialNodes := ns.InitialNodes + initialNodesToExport := make([]*config.InitialNodeConfig, 0, len(initialNodes)) + for _, item := range initialNodes { + initialNodesToExport = append(initialNodesToExport, &config.InitialNodeConfig{ + PubKey: item.PubKey, + Address: item.Address, + InitialRating: item.InitialRating, + }) + } + + return config.NodesConfig{ + StartTime: ns.StartTime, + InitialNodes: initialNodesToExport, + } } // IsInterfaceNil returns true if underlying object is nil diff --git a/sharding/nodesSetup_test.go b/sharding/nodesSetup_test.go index ca8d3ce479b..f5e6de19a8b 100644 --- a/sharding/nodesSetup_test.go +++ b/sharding/nodesSetup_test.go @@ -5,8 +5,11 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/sharding/mock" - "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/testscommon/chainParameters" ) var ( @@ -29,10 +32,10 @@ var ( } ) -func createAndAssignNodes(ns NodesSetup, noOfInitialNodes int) *NodesSetup { - ns.InitialNodes = make([]*InitialNode, noOfInitialNodes) +func createAndAssignNodes(ns *NodesSetup, numInitialNodes int) *NodesSetup { + ns.InitialNodes = make([]*InitialNode, numInitialNodes) - for i := 0; i < noOfInitialNodes; i++ { + for i := 0; i < numInitialNodes; i++ { lookupIndex := i % len(pubKeys) ns.InitialNodes[i] = &InitialNode{} ns.InitialNodes[i].PubKey = pubKeys[lookupIndex] @@ -48,178 +51,85 @@ func 
createAndAssignNodes(ns NodesSetup, noOfInitialNodes int) *NodesSetup { ns.processShardAssignment() ns.createInitialNodesInfo() - return &ns -} - -func createNodesSetupOneShardOneNodeWithOneMeta() *NodesSetup { - ns := &NodesSetup{ - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - genesisMaxNumShards: 100, - } - ns.ConsensusGroupSize = 1 - ns.MinNodesPerShard = 1 - ns.MetaChainConsensusGroupSize = 1 - ns.MetaChainMinNodes = 1 - ns.InitialNodes = []*InitialNode{ - { - PubKey: pubKeys[0], - Address: address[0], - }, - { - PubKey: pubKeys[1], - Address: address[1], - }, - } - - err := ns.processConfig() - if err != nil { - return nil - } - - ns.processMetaChainAssigment() - ns.processShardAssignment() - ns.createInitialNodesInfo() - - return ns -} - -func initNodesConfig(ns *NodesSetup, noOfInitialNodes int) bool { - ns.InitialNodes = make([]*InitialNode, noOfInitialNodes) - - for i := 0; i < noOfInitialNodes; i++ { - ns.InitialNodes[i] = &InitialNode{} - ns.InitialNodes[i].PubKey = pubKeys[i] - ns.InitialNodes[i].Address = address[i] - } - - err := ns.processConfig() - if err != nil { - return false - } - - ns.processMetaChainAssigment() - ns.processShardAssignment() - ns.createInitialNodesInfo() - return true -} - -func createNodesSetupTwoShardTwoNodesWithOneMeta() *NodesSetup { - noOfInitialNodes := 6 - ns := &NodesSetup{ - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - genesisMaxNumShards: 100, - } - ns.ConsensusGroupSize = 1 - ns.MinNodesPerShard = 2 - ns.MetaChainConsensusGroupSize = 1 - ns.MetaChainMinNodes = 2 - ok := initNodesConfig(ns, noOfInitialNodes) - if !ok { - return nil - } - return ns } -func createNodesSetupTwoShard5NodesWithMeta() *NodesSetup { - noOfInitialNodes := 5 - ns := &NodesSetup{ - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - genesisMaxNumShards: 100, - } - ns.ConsensusGroupSize = 1 - ns.MinNodesPerShard = 2 - ns.MetaChainConsensusGroupSize = 1 - ns.MetaChainMinNodes = 1 - ok := initNodesConfig(ns, noOfInitialNodes) - if !ok { - return nil - } - - return ns +type argsTestNodesSetup struct { + shardConsensusSize uint32 + shardMinNodes uint32 + metaConsensusSize uint32 + metaMinNodes uint32 + numInitialNodes uint32 + genesisMaxShards uint32 } -func createNodesSetupTwoShard6NodesMeta() *NodesSetup { - noOfInitialNodes := 6 - ns := &NodesSetup{ - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - genesisMaxNumShards: 100, - } - ns.ConsensusGroupSize = 1 - ns.MinNodesPerShard = 2 - ns.MetaChainMinNodes = 2 - ns.MetaChainConsensusGroupSize = 2 - ok := initNodesConfig(ns, noOfInitialNodes) - if !ok { - return nil +func createTestNodesSetup(args argsTestNodesSetup) (*NodesSetup, error) { + initialNodes := make([]*config.InitialNodeConfig, 0) + for i := 0; uint32(i) < args.numInitialNodes; i++ { + lookupIndex := i % len(pubKeys) + initialNodes = append(initialNodes, &config.InitialNodeConfig{ + PubKey: pubKeys[lookupIndex], + Address: address[lookupIndex], + }) } - - return ns -} - -func TestNodesSetup_NewNodesSetupWrongFile(t *testing.T) { - t.Parallel() - ns, err := NewNodesSetup( - "", - mock.NewPubkeyConverterMock(32), - mock.NewPubkeyConverterMock(96), - 100, - ) - - assert.Nil(t, ns) - assert.NotNil(t, err) -} - -func TestNodesSetup_NewNodesSetupWrongDataInFile(t *testing.T) { - 
t.Parallel() - - ns, err := NewNodesSetup( - "mock/testdata/invalidNodesSetupMock.json", + config.NodesConfig{ + StartTime: 0, + InitialNodes: initialNodes, + }, + &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + EnableEpoch: 0, + ShardMinNumNodes: args.shardMinNodes, + ShardConsensusGroupSize: args.shardConsensusSize, + MetachainMinNumNodes: args.metaMinNodes, + MetachainConsensusGroupSize: args.metaConsensusSize, + }, nil + }, + }, mock.NewPubkeyConverterMock(32), mock.NewPubkeyConverterMock(96), - 100, + args.genesisMaxShards, ) - assert.Nil(t, ns) - assert.Equal(t, ErrNegativeOrZeroConsensusGroupSize, err) + return ns, err } -func TestNodesSetup_NewNodesShouldWork(t *testing.T) { - t.Parallel() +func createTestNodesSetupWithSpecificMockedComponents(args argsTestNodesSetup, + initialNodes []*config.InitialNodeConfig, + addressPubkeyConverter core.PubkeyConverter, + validatorPubkeyConverter core.PubkeyConverter) (*NodesSetup, error) { ns, err := NewNodesSetup( - "mock/testdata/nodesSetupMock.json", - mock.NewPubkeyConverterMock(32), - mock.NewPubkeyConverterMock(96), - 100, + config.NodesConfig{ + StartTime: 0, + InitialNodes: initialNodes, + }, + &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{ + EnableEpoch: 0, + ShardMinNumNodes: args.shardMinNodes, + ShardConsensusGroupSize: args.shardConsensusSize, + MetachainMinNumNodes: args.metaMinNodes, + MetachainConsensusGroupSize: args.metaConsensusSize, + }, nil + }, + }, + addressPubkeyConverter, + validatorPubkeyConverter, + args.genesisMaxShards, ) - assert.NotNil(t, ns) - assert.Nil(t, err) - assert.Equal(t, 5, len(ns.InitialNodes)) -} - -func TestNodesSetup_InitialNodesPubKeysFromNil(t *testing.T) { - t.Parallel() - - ns := NodesSetup{} - eligible, waiting := ns.InitialNodesInfo() - - assert.NotNil(t, ns) - assert.Nil(t, eligible) - assert.Nil(t, waiting) + return ns, err } func TestNodesSetup_ProcessConfigNodesWithIncompleteDataShouldErr(t *testing.T) { t.Parallel() noOfInitialNodes := 2 - ns := NodesSetup{ + ns := &NodesSetup{ addressPubkeyConverter: mock.NewPubkeyConverterMock(32), validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), } @@ -234,496 +144,930 @@ func TestNodesSetup_ProcessConfigNodesWithIncompleteDataShouldErr(t *testing.T) err := ns.processConfig() - assert.NotNil(t, ns) - assert.Equal(t, ErrCouldNotParsePubKey, err) + require.NotNil(t, ns) + require.Equal(t, ErrCouldNotParsePubKey, err) } -func TestNodesSetup_ProcessConfigInvalidConsensusGroupSizeShouldErr(t *testing.T) { +func TestNodesSetup_ProcessConfigNodesShouldErrCouldNotParsePubKeyForString(t *testing.T) { t.Parallel() - noOfInitialNodes := 2 - ns := NodesSetup{ - ConsensusGroupSize: 0, - MinNodesPerShard: 0, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), + mockedNodes := make([]*config.InitialNodeConfig, 2) + mockedNodes[0] = &config.InitialNodeConfig{ + PubKey: pubKeys[0], + Address: address[0], } - ns.InitialNodes = make([]*InitialNode, noOfInitialNodes) + mockedNodes[1] = &config.InitialNodeConfig{ + PubKey: pubKeys[1], + Address: address[1], + } - for i := 0; i < noOfInitialNodes; i++ { - ns.InitialNodes[i] = &InitialNode{} - ns.InitialNodes[i].PubKey = pubKeys[i] - ns.InitialNodes[i].Address = 
address[i] + addressPubkeyConverterMocked := mock.NewPubkeyConverterMock(32) + validatorPubkeyConverterMocked := &mock.PubkeyConverterMock{ + DecodeCalled: func() ([]byte, error) { + return nil, ErrCouldNotParsePubKey + }, } - err := ns.processConfig() + _, err := createTestNodesSetupWithSpecificMockedComponents(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 2, + genesisMaxShards: 3, + }, + mockedNodes, + addressPubkeyConverterMocked, + validatorPubkeyConverterMocked, + ) - assert.NotNil(t, ns) - assert.Equal(t, ErrNegativeOrZeroConsensusGroupSize, err) + require.ErrorIs(t, err, ErrCouldNotParsePubKey) } -func TestNodesSetup_ProcessConfigInvalidMetaConsensusGroupSizeShouldErr(t *testing.T) { +func TestNodesSetup_ProcessConfigNodesShouldErrCouldNotParseAddressForString(t *testing.T) { t.Parallel() - noOfInitialNodes := 2 - ns := NodesSetup{ - ConsensusGroupSize: 1, - MinNodesPerShard: 1, - MetaChainConsensusGroupSize: 0, - MetaChainMinNodes: 0, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), + mockedNodes := make([]*config.InitialNodeConfig, 2) + mockedNodes[0] = &config.InitialNodeConfig{ + PubKey: pubKeys[0], + Address: address[0], } - ns.InitialNodes = make([]*InitialNode, noOfInitialNodes) - - for i := 0; i < noOfInitialNodes; i++ { - ns.InitialNodes[i] = &InitialNode{} - ns.InitialNodes[i].PubKey = pubKeys[i] - ns.InitialNodes[i].Address = address[i] + mockedNodes[1] = &config.InitialNodeConfig{ + PubKey: pubKeys[1], + Address: address[1], } - err := ns.processConfig() + addressPubkeyConverterMocked := &mock.PubkeyConverterMock{ + DecodeCalled: func() ([]byte, error) { + return nil, ErrCouldNotParseAddress + }, + } + validatorPubkeyConverterMocked := mock.NewPubkeyConverterMock(96) + + _, err := createTestNodesSetupWithSpecificMockedComponents(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 2, + genesisMaxShards: 3, + }, + mockedNodes, + addressPubkeyConverterMocked, + validatorPubkeyConverterMocked, + ) - assert.NotNil(t, ns) - assert.Equal(t, ErrNegativeOrZeroConsensusGroupSize, err) + require.ErrorIs(t, err, ErrCouldNotParseAddress) } -func TestNodesSetup_ProcessConfigInvalidConsensusGroupSizeLargerThanNumOfNodesShouldErr(t *testing.T) { +func TestNodesSetup_ProcessConfigNodesWithEmptyDataShouldErrCouldNotParseAddress(t *testing.T) { t.Parallel() - noOfInitialNodes := 2 - ns := NodesSetup{ - ConsensusGroupSize: 2, - MinNodesPerShard: 0, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), + mockedNodes := make([]*config.InitialNodeConfig, 2) + mockedNodes[0] = &config.InitialNodeConfig{ + PubKey: pubKeys[0], + Address: address[0], } - ns.InitialNodes = make([]*InitialNode, noOfInitialNodes) - - for i := 0; i < noOfInitialNodes; i++ { - ns.InitialNodes[i] = &InitialNode{} - ns.InitialNodes[i].PubKey = pubKeys[i] - ns.InitialNodes[i].Address = address[i] + mockedNodes[1] = &config.InitialNodeConfig{ + PubKey: pubKeys[1], + Address: "", } - err := ns.processConfig() + addressPubkeyConverterMocked := mock.NewPubkeyConverterMock(32) + validatorPubkeyConverterMocked := mock.NewPubkeyConverterMock(96) + + _, err := createTestNodesSetupWithSpecificMockedComponents(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 2, + 
genesisMaxShards: 3, + }, + mockedNodes, + addressPubkeyConverterMocked, + validatorPubkeyConverterMocked, + ) - assert.NotNil(t, ns) - assert.Equal(t, ErrMinNodesPerShardSmallerThanConsensusSize, err) + require.ErrorIs(t, err, ErrCouldNotParseAddress) } -func TestNodesSetup_ProcessConfigInvalidMetaConsensusGroupSizeLargerThanNumOfNodesShouldErr(t *testing.T) { +func TestNodesSetup_ProcessConfigInvalidConsensusGroupSizeShouldErr(t *testing.T) { t.Parallel() - noOfInitialNodes := 2 - ns := NodesSetup{ - ConsensusGroupSize: 1, - MinNodesPerShard: 1, - MetaChainConsensusGroupSize: 1, - MetaChainMinNodes: 0, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - } - - ns.InitialNodes = make([]*InitialNode, 2) - - for i := 0; i < noOfInitialNodes; i++ { - ns.InitialNodes[i] = &InitialNode{} - ns.InitialNodes[i].PubKey = pubKeys[i] - ns.InitialNodes[i].Address = address[i] - } - - err := ns.processConfig() - - assert.NotNil(t, ns) - assert.Equal(t, ErrMinNodesPerShardSmallerThanConsensusSize, err) + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 0, + shardMinNodes: 0, + metaConsensusSize: 0, + metaMinNodes: 0, + numInitialNodes: 0, + genesisMaxShards: 3, + }) + require.Equal(t, ErrNegativeOrZeroConsensusGroupSize, err) + require.Nil(t, ns) } -func TestNodesSetup_ProcessConfigInvalidMinNodesPerShardShouldErr(t *testing.T) { +func TestNodesSetup_ProcessConfigInvalidMetaConsensusGroupSizeShouldErr(t *testing.T) { t.Parallel() - noOfInitialNodes := 2 - ns := NodesSetup{ - ConsensusGroupSize: 2, - MinNodesPerShard: 0, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - } - - ns.InitialNodes = make([]*InitialNode, noOfInitialNodes) - - for i := 0; i < noOfInitialNodes; i++ { - ns.InitialNodes[i] = &InitialNode{} - ns.InitialNodes[i].PubKey = pubKeys[i] - ns.InitialNodes[i].Address = address[i] - } - - err := ns.processConfig() - - assert.NotNil(t, ns) - assert.Equal(t, ErrMinNodesPerShardSmallerThanConsensusSize, err) + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 0, + metaMinNodes: 0, + numInitialNodes: 1, + genesisMaxShards: 3, + }) + require.Equal(t, ErrNegativeOrZeroConsensusGroupSize, err) + require.Nil(t, ns) } -func TestNodesSetup_ProcessConfigInvalidMetaMinNodesPerShardShouldErr(t *testing.T) { +func TestNodesSetup_ProcessConfigInvalidConsensusGroupSizeLargerThanNumOfNodesShouldErr(t *testing.T) { t.Parallel() - noOfInitialNodes := 1 - ns := NodesSetup{ - ConsensusGroupSize: 1, - MinNodesPerShard: 1, - MetaChainConsensusGroupSize: 1, - MetaChainMinNodes: 0, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - } - - ns.InitialNodes = make([]*InitialNode, noOfInitialNodes) - - for i := 0; i < noOfInitialNodes; i++ { - ns.InitialNodes[i] = &InitialNode{} - ns.InitialNodes[i].PubKey = pubKeys[i] - ns.InitialNodes[i].Address = address[i] - } - - err := ns.processConfig() - - assert.NotNil(t, ns) - assert.Equal(t, ErrMinNodesPerShardSmallerThanConsensusSize, err) + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 2, + shardMinNodes: 0, + metaConsensusSize: 0, + metaMinNodes: 0, + numInitialNodes: 2, + genesisMaxShards: 3, + }) + require.Equal(t, ErrMinNodesPerShardSmallerThanConsensusSize, err) + require.Nil(t, ns) } -func 
TestNodesSetup_ProcessConfigInvalidNumOfNodesSmallerThanMinNodesPerShardShouldErr(t *testing.T) { +func TestNodesSetup_ProcessConfigInvalidMetaConsensusGroupSizeLargerThanNumOfNodesShouldErr(t *testing.T) { t.Parallel() - noOfInitialNodes := 2 - ns := NodesSetup{ - ConsensusGroupSize: 2, - MinNodesPerShard: 3, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - } - - ns.InitialNodes = make([]*InitialNode, noOfInitialNodes) - - for i := 0; i < noOfInitialNodes; i++ { - ns.InitialNodes[i] = &InitialNode{} - ns.InitialNodes[i].PubKey = pubKeys[i] - ns.InitialNodes[i].Address = address[i] - } - - err := ns.processConfig() - - assert.NotNil(t, ns) - assert.Equal(t, ErrNodesSizeSmallerThanMinNoOfNodes, err) + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 2, + metaMinNodes: 1, + numInitialNodes: 2, + genesisMaxShards: 3, + }) + require.Equal(t, ErrMinNodesPerShardSmallerThanConsensusSize, err) + require.Nil(t, ns) } -func TestNodesSetup_ProcessConfigInvalidMetaNumOfNodesSmallerThanMinNodesPerShardShouldErr(t *testing.T) { +func TestNodesSetup_ProcessConfigInvalidNumOfNodesSmallerThanMinNodesPerShardShouldErr(t *testing.T) { t.Parallel() - noOfInitialNodes := 3 - ns := NodesSetup{ - ConsensusGroupSize: 1, - MinNodesPerShard: 1, - MetaChainConsensusGroupSize: 2, - MetaChainMinNodes: 3, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - } - - ns.InitialNodes = make([]*InitialNode, noOfInitialNodes) - - for i := 0; i < noOfInitialNodes; i++ { - ns.InitialNodes[i] = &InitialNode{} - ns.InitialNodes[i].PubKey = pubKeys[i] - ns.InitialNodes[i].Address = address[i] - } - - err := ns.processConfig() - - assert.NotNil(t, ns) - assert.Equal(t, ErrNodesSizeSmallerThanMinNoOfNodes, err) + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 2, + shardMinNodes: 3, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 2, + genesisMaxShards: 3, + }) + require.Nil(t, ns) + require.Equal(t, ErrNodesSizeSmallerThanMinNoOfNodes, err) } -func TestNodesSetup_InitialNodesPubKeysForShardNil(t *testing.T) { +func TestNodesSetup_ProcessConfigInvalidNumOfNodesSmallerThanTotalMinNodesShouldErr(t *testing.T) { t.Parallel() - ns := NodesSetup{ - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - } - eligible, waiting, err := ns.InitialNodesInfoForShard(0) - - assert.NotNil(t, ns) - assert.Nil(t, eligible) - assert.Nil(t, waiting) - assert.NotNil(t, err) + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 2, + shardMinNodes: 3, + metaConsensusSize: 1, + metaMinNodes: 3, + numInitialNodes: 5, + genesisMaxShards: 3, + }) + require.Nil(t, ns) + require.Equal(t, ErrNodesSizeSmallerThanMinNoOfNodes, err) } func TestNodesSetup_InitialNodesPubKeysWithHysteresis(t *testing.T) { t.Parallel() - ns := &NodesSetup{ - ConsensusGroupSize: 63, - MinNodesPerShard: 400, - MetaChainConsensusGroupSize: 400, - MetaChainMinNodes: 400, - Hysteresis: 0.2, - Adaptivity: false, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - genesisMaxNumShards: 100, - } - - ns = createAndAssignNodes(*ns, 3000) - - assert.Equal(t, 6, len(ns.eligible)) + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 63, + shardMinNodes: 400, + 
metaConsensusSize: 400, + metaMinNodes: 400, + numInitialNodes: 3000, + genesisMaxShards: 100, + }) + ns.Hysteresis = 0.2 + ns.Adaptivity = false + require.NoError(t, err) + + ns = createAndAssignNodes(ns, 3000) + require.Equal(t, 6, len(ns.eligible)) for shard, shardNodes := range ns.eligible { - assert.Equal(t, 400, len(shardNodes)) - assert.Equal(t, 100, len(ns.waiting[shard])) + require.Equal(t, 400, len(shardNodes)) + require.Equal(t, 100, len(ns.waiting[shard])) } - ns = createAndAssignNodes(*ns, 3570) - assert.Equal(t, 7, len(ns.eligible)) + ns = createAndAssignNodes(ns, 3570) + require.Equal(t, 7, len(ns.eligible)) for shard, shardNodes := range ns.eligible { - assert.Equal(t, 400, len(shardNodes)) - assert.Equal(t, 110, len(ns.waiting[shard])) + require.Equal(t, 400, len(shardNodes)) + require.Equal(t, 110, len(ns.waiting[shard])) } - ns = createAndAssignNodes(*ns, 2400) - assert.Equal(t, 5, len(ns.eligible)) + ns = createAndAssignNodes(ns, 2400) + require.Equal(t, 5, len(ns.eligible)) for shard, shardNodes := range ns.eligible { - assert.Equal(t, 400, len(shardNodes)) - assert.Equal(t, 80, len(ns.waiting[shard])) + require.Equal(t, 400, len(shardNodes)) + require.Equal(t, 80, len(ns.waiting[shard])) } } func TestNodesSetup_InitialNodesPubKeysForShardWrongShard(t *testing.T) { t.Parallel() - ns := createNodesSetupOneShardOneNodeWithOneMeta() + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 2, + genesisMaxShards: 3, + }) + require.NoError(t, err) eligible, waiting, err := ns.InitialNodesInfoForShard(1) - assert.NotNil(t, ns) - assert.Nil(t, eligible) - assert.Nil(t, waiting) - assert.NotNil(t, err) + require.NotNil(t, ns) + require.Nil(t, eligible) + require.Nil(t, waiting) + require.NotNil(t, err) } func TestNodesSetup_InitialNodesPubKeysForShardGood(t *testing.T) { t.Parallel() - ns := createNodesSetupTwoShardTwoNodesWithOneMeta() + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 2, + metaConsensusSize: 1, + metaMinNodes: 2, + numInitialNodes: 6, + genesisMaxShards: 3, + }) + require.NoError(t, err) + eligible, waiting, err := ns.InitialNodesInfoForShard(1) - assert.NotNil(t, ns) - assert.Equal(t, 2, len(eligible)) - assert.Equal(t, 0, len(waiting)) - assert.Nil(t, err) + require.NotNil(t, ns) + require.Equal(t, 2, len(eligible)) + require.Equal(t, 0, len(waiting)) + require.Nil(t, err) } func TestNodesSetup_InitialNodesPubKeysForShardGoodMeta(t *testing.T) { t.Parallel() - ns := createNodesSetupTwoShard6NodesMeta() + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 2, + metaConsensusSize: 2, + metaMinNodes: 2, + numInitialNodes: 6, + genesisMaxShards: 3, + }) + require.NoError(t, err) metaId := core.MetachainShardId eligible, waiting, err := ns.InitialNodesInfoForShard(metaId) - assert.NotNil(t, ns) - assert.Equal(t, 2, len(eligible)) - assert.Equal(t, 0, len(waiting)) - assert.Nil(t, err) + require.NotNil(t, ns) + require.Equal(t, 2, len(eligible)) + require.Equal(t, 0, len(waiting)) + require.Nil(t, err) } func TestNodesSetup_PublicKeyNotGood(t *testing.T) { t.Parallel() - ns := createNodesSetupTwoShard6NodesMeta() + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 5, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 6, + genesisMaxShards: 3, + }) + require.NoError(t, err) - _, err := 
ns.GetShardIDForPubKey([]byte(pubKeys[0])) + _, err = ns.GetShardIDForPubKey([]byte(pubKeys[0])) - assert.NotNil(t, ns) - assert.NotNil(t, err) + require.NotNil(t, ns) + require.NotNil(t, err) } func TestNodesSetup_PublicKeyGood(t *testing.T) { t.Parallel() - ns := createNodesSetupTwoShard5NodesWithMeta() + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 5, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 6, + genesisMaxShards: 3, + }) + require.NoError(t, err) + publicKey, _ := hex.DecodeString(pubKeys[2]) selfId, err := ns.GetShardIDForPubKey(publicKey) - assert.NotNil(t, ns) - assert.Nil(t, err) - assert.Equal(t, uint32(0), selfId) + require.NotNil(t, ns) + require.Nil(t, err) + require.Equal(t, uint32(0), selfId) } func TestNodesSetup_ShardPublicKeyGoodMeta(t *testing.T) { t.Parallel() - ns := createNodesSetupTwoShard6NodesMeta() + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 5, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 6, + genesisMaxShards: 3, + }) + require.NoError(t, err) publicKey, _ := hex.DecodeString(pubKeys[2]) selfId, err := ns.GetShardIDForPubKey(publicKey) - assert.NotNil(t, ns) - assert.Nil(t, err) - assert.Equal(t, uint32(0), selfId) + require.NotNil(t, ns) + require.Nil(t, err) + require.Equal(t, uint32(0), selfId) } func TestNodesSetup_MetaPublicKeyGoodMeta(t *testing.T) { t.Parallel() - ns := createNodesSetupTwoShard6NodesMeta() + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 5, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 6, + genesisMaxShards: 3, + }) + require.NoError(t, err) metaId := core.MetachainShardId publicKey, _ := hex.DecodeString(pubKeys[0]) selfId, err := ns.GetShardIDForPubKey(publicKey) - assert.NotNil(t, ns) - assert.Nil(t, err) - assert.Equal(t, metaId, selfId) + require.NotNil(t, ns) + require.Nil(t, err) + require.Equal(t, metaId, selfId) } func TestNodesSetup_MinNumberOfNodes(t *testing.T) { t.Parallel() - ns := &NodesSetup{ - ConsensusGroupSize: 63, - MinNodesPerShard: 400, - MetaChainConsensusGroupSize: 400, - MetaChainMinNodes: 400, - Hysteresis: 0.2, - Adaptivity: false, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - genesisMaxNumShards: 100, - } - - ns = createAndAssignNodes(*ns, 2169) - assert.Equal(t, 4, len(ns.eligible)) + + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 63, + shardMinNodes: 400, + metaConsensusSize: 400, + metaMinNodes: 400, + numInitialNodes: 2169, + genesisMaxShards: 3, + }) + ns.Hysteresis = 0.2 + ns.Adaptivity = false + require.NoError(t, err) + + ns = createAndAssignNodes(ns, 2169) + require.Equal(t, 4, len(ns.eligible)) for shard, shardNodes := range ns.eligible { - assert.Equal(t, 400, len(shardNodes)) - assert.LessOrEqual(t, len(ns.waiting[shard]), 143) - assert.GreaterOrEqual(t, len(ns.waiting[shard]), 142) + require.Equal(t, 400, len(shardNodes)) + require.LessOrEqual(t, len(ns.waiting[shard]), 143) + require.GreaterOrEqual(t, len(ns.waiting[shard]), 142) } minNumNodes := ns.MinNumberOfNodes() - assert.Equal(t, uint32(1600), minNumNodes) + require.Equal(t, uint32(1600), minNumNodes) minHysteresisNodesShard := ns.MinShardHysteresisNodes() - assert.Equal(t, uint32(80), minHysteresisNodesShard) + require.Equal(t, uint32(80), minHysteresisNodesShard) minHysteresisNodesMeta := ns.MinMetaHysteresisNodes() - assert.Equal(t, 
uint32(80), minHysteresisNodesMeta) + require.Equal(t, uint32(80), minHysteresisNodesMeta) } func TestNewNodesSetup_InvalidMaxNumShardsShouldErr(t *testing.T) { t.Parallel() ns, err := NewNodesSetup( - "", + config.NodesConfig{}, + &chainParameters.ChainParametersHandlerStub{}, mock.NewPubkeyConverterMock(32), mock.NewPubkeyConverterMock(96), 0, ) - assert.Nil(t, ns) - assert.NotNil(t, err) - assert.Contains(t, err.Error(), ErrInvalidMaximumNumberOfShards.Error()) + require.Nil(t, ns) + require.NotNil(t, err) + require.Contains(t, err.Error(), ErrInvalidMaximumNumberOfShards.Error()) } -func TestNodesSetup_IfNodesWithinMaxShardLimitEquivalentDistribution(t *testing.T) { +func TestNewNodesSetup_ErrNilPubkeyConverterForAddressPubkeyConverter(t *testing.T) { t.Parallel() - ns := &NodesSetup{ - ConsensusGroupSize: 63, - MinNodesPerShard: 400, - MetaChainConsensusGroupSize: 400, - MetaChainMinNodes: 400, - Hysteresis: 0.2, - Adaptivity: false, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - genesisMaxNumShards: 100, - } + _, err := NewNodesSetup( + config.NodesConfig{}, + &chainParameters.ChainParametersHandlerStub{}, + nil, + mock.NewPubkeyConverterMock(96), + 3, + ) + + require.ErrorIs(t, err, ErrNilPubkeyConverter) +} + +func TestNewNodesSetup_ErrNilPubkeyConverterForValidatorPubkeyConverter(t *testing.T) { + t.Parallel() + + _, err := NewNodesSetup( + config.NodesConfig{}, + &chainParameters.ChainParametersHandlerStub{}, + mock.NewPubkeyConverterMock(32), + nil, + 3, + ) + + require.ErrorIs(t, err, ErrNilPubkeyConverter) +} + +func TestNewNodesSetup_ErrNilChainParametersProvider(t *testing.T) { + t.Parallel() + + _, err := NewNodesSetup( + config.NodesConfig{}, + nil, + mock.NewPubkeyConverterMock(32), + mock.NewPubkeyConverterMock(96), + 3, + ) - ns = createAndAssignNodes(*ns, 2169) + require.Equal(t, err, ErrNilChainParametersProvider) +} + +func TestNewNodesSetup_ErrChainParametersForEpoch(t *testing.T) { + t.Parallel() + + _, err := NewNodesSetup( + config.NodesConfig{}, + &chainParameters.ChainParametersHandlerStub{ + ChainParametersForEpochCalled: func(epoch uint32) (config.ChainParametersByEpochConfig, error) { + return config.ChainParametersByEpochConfig{}, ErrInvalidChainParametersForEpoch + }, + }, + mock.NewPubkeyConverterMock(32), + mock.NewPubkeyConverterMock(96), + 3, + ) + + require.ErrorIs(t, err, ErrInvalidChainParametersForEpoch) +} + +func TestNodesSetup_IfNodesWithinMaxShardLimitEquivalentDistribution(t *testing.T) { + t.Parallel() + + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 64, + shardMinNodes: 400, + metaConsensusSize: 400, + metaMinNodes: 400, + numInitialNodes: 2169, + genesisMaxShards: 3, + }) + ns.Hysteresis = 0.2 + ns.Adaptivity = false + require.NoError(t, err) ns2 := &(*ns) //nolint ns2.genesisMaxNumShards = 3 - ns2 = createAndAssignNodes(*ns2, 2169) + ns2 = createAndAssignNodes(ns2, 2169) - assert.Equal(t, 4, len(ns.eligible)) - assert.Equal(t, 4, len(ns2.eligible)) + require.Equal(t, 4, len(ns.eligible)) + require.Equal(t, 4, len(ns2.eligible)) for shard, shardNodes := range ns.eligible { - assert.Equal(t, len(shardNodes), len(ns2.eligible[shard])) - assert.Equal(t, len(ns.waiting[shard]), len(ns2.waiting[shard])) - assert.GreaterOrEqual(t, len(ns.waiting[shard]), 142) - assert.Equal(t, len(ns.waiting[shard]), len(ns2.waiting[shard])) + require.Equal(t, len(shardNodes), len(ns2.eligible[shard])) + require.Equal(t, len(ns.waiting[shard]), 
len(ns2.waiting[shard])) + require.GreaterOrEqual(t, len(ns.waiting[shard]), 142) + require.Equal(t, len(ns.waiting[shard]), len(ns2.waiting[shard])) for i, node := range shardNodes { - assert.Equal(t, node, ns2.eligible[shard][i]) + require.Equal(t, node, ns2.eligible[shard][i]) } for i, node := range ns.waiting[shard] { - assert.Equal(t, node, ns2.waiting[shard][i]) + require.Equal(t, node, ns2.waiting[shard][i]) } } minNumNodes := ns.MinNumberOfNodes() - assert.Equal(t, minNumNodes, ns2.MinNumberOfNodes()) + require.Equal(t, minNumNodes, ns2.MinNumberOfNodes()) minHysteresisNodesShard := ns.MinShardHysteresisNodes() - assert.Equal(t, minHysteresisNodesShard, ns2.MinShardHysteresisNodes()) + require.Equal(t, minHysteresisNodesShard, ns2.MinShardHysteresisNodes()) minHysteresisNodesMeta := ns.MinMetaHysteresisNodes() - assert.Equal(t, minHysteresisNodesMeta, ns2.MinMetaHysteresisNodes()) + require.Equal(t, minHysteresisNodesMeta, ns2.MinMetaHysteresisNodes()) } func TestNodesSetup_NodesAboveMaxShardLimit(t *testing.T) { t.Parallel() - ns := &NodesSetup{ - ConsensusGroupSize: 63, - MinNodesPerShard: 400, - MetaChainConsensusGroupSize: 400, - MetaChainMinNodes: 400, - Hysteresis: 0.2, - Adaptivity: false, - addressPubkeyConverter: mock.NewPubkeyConverterMock(32), - validatorPubkeyConverter: mock.NewPubkeyConverterMock(96), - genesisMaxNumShards: 3, - } - - ns = createAndAssignNodes(*ns, 3200) - - assert.Equal(t, 4, len(ns.eligible)) + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 63, + shardMinNodes: 400, + metaConsensusSize: 400, + metaMinNodes: 400, + numInitialNodes: 3200, + genesisMaxShards: 3, + }) + ns.Hysteresis = 0.2 + ns.Adaptivity = false + require.NoError(t, err) + + require.Equal(t, 4, len(ns.eligible)) for shard, shardNodes := range ns.eligible { - assert.Equal(t, 400, len(shardNodes)) - assert.Equal(t, len(ns.waiting[shard]), 400) + require.Equal(t, 400, len(shardNodes)) + require.Equal(t, len(ns.waiting[shard]), 400) } minNumNodes := ns.MinNumberOfNodes() - assert.Equal(t, uint32(1600), minNumNodes) + require.Equal(t, uint32(1600), minNumNodes) minHysteresisNodesShard := ns.MinShardHysteresisNodes() - assert.Equal(t, uint32(80), minHysteresisNodesShard) + require.Equal(t, uint32(80), minHysteresisNodesShard) minHysteresisNodesMeta := ns.MinMetaHysteresisNodes() - assert.Equal(t, uint32(80), minHysteresisNodesMeta) + require.Equal(t, uint32(80), minHysteresisNodesMeta) - ns = createAndAssignNodes(*ns, 3600) + ns = createAndAssignNodes(ns, 3600) for shard, shardNodes := range ns.eligible { - assert.Equal(t, 400, len(shardNodes)) - assert.Equal(t, len(ns.waiting[shard]), 500) + require.Equal(t, 400, len(shardNodes)) + require.Equal(t, len(ns.waiting[shard]), 500) } minNumNodes = ns.MinNumberOfNodes() - assert.Equal(t, uint32(1600), minNumNodes) + require.Equal(t, uint32(1600), minNumNodes) minHysteresisNodesShard = ns.MinShardHysteresisNodes() - assert.Equal(t, uint32(80), minHysteresisNodesShard) + require.Equal(t, uint32(80), minHysteresisNodesShard) minHysteresisNodesMeta = ns.MinMetaHysteresisNodes() - assert.Equal(t, uint32(80), minHysteresisNodesMeta) + require.Equal(t, uint32(80), minHysteresisNodesMeta) +} + +func TestNodesSetup_AllInitialNodesShouldWork(t *testing.T) { + t.Parallel() + + noOfInitialNodes := 2 + + var listOfInitialNodes = [2]InitialNode{ + { + PubKey: pubKeys[0], + Address: address[0], + }, + { + PubKey: pubKeys[1], + Address: address[1], + }, + } + + var expectedConvertedPubKeys = make([][]byte, 2) + pubKeyConverter := 
mock.NewPubkeyConverterMock(96) + + for i, nod := range listOfInitialNodes { + convertedValue, err := pubKeyConverter.Decode(nod.PubKey) + require.Nil(t, err) + require.NotNil(t, convertedValue) + expectedConvertedPubKeys[i] = convertedValue + } + + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 2, + genesisMaxShards: 1, + }) + + require.Nil(t, err) + ns.Hysteresis = 0.2 + ns.Adaptivity = false + + ns = createAndAssignNodes(ns, noOfInitialNodes) + + allInitialNodes := ns.AllInitialNodes() + + for i, expectedConvertedKey := range expectedConvertedPubKeys { + require.Equal(t, expectedConvertedKey, allInitialNodes[i].PubKeyBytes()) + } + +} + +func TestNodesSetup_InitialNodesInfoShouldWork(t *testing.T) { + t.Parallel() + + noOfInitialNodes := 3 + + var listOfInitialNodes = [3]InitialNode{ + { + PubKey: pubKeys[0], + Address: address[0], + }, + { + PubKey: pubKeys[1], + Address: address[1], + }, + { + PubKey: pubKeys[2], + Address: address[2], + }, + } + + var listOfExpectedConvertedPubKeysEligibleNodes = make([][]byte, 2) + pubKeyConverter := mock.NewPubkeyConverterMock(96) + + for i := 0; i < 2; i++ { + convertedValue, err := pubKeyConverter.Decode(listOfInitialNodes[i].PubKey) + require.Nil(t, err) + require.NotNil(t, convertedValue) + listOfExpectedConvertedPubKeysEligibleNodes[i] = convertedValue + } + + var listOfExpectedConvertedPubKeysWaitingNode = make([][]byte, 1) + listOfExpectedConvertedPubKeysWaitingNode[0], _ = pubKeyConverter.Decode(listOfInitialNodes[2].PubKey) + + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 3, + genesisMaxShards: 1, + }) + require.Nil(t, err) + ns.Hysteresis = 0.2 + ns.Adaptivity = false + + ns = createAndAssignNodes(ns, noOfInitialNodes) + + allEligibleNodes, allWaitingNodes := ns.InitialNodesInfo() + + require.Equal(t, listOfExpectedConvertedPubKeysEligibleNodes[0], allEligibleNodes[(core.MetachainShardId)][0].PubKeyBytes()) + require.Equal(t, listOfExpectedConvertedPubKeysEligibleNodes[1], allEligibleNodes[0][0].PubKeyBytes()) + require.Equal(t, listOfExpectedConvertedPubKeysWaitingNode[0], allWaitingNodes[(core.MetachainShardId)][0].PubKeyBytes()) + +} + +func TestNodesSetup_InitialNodesPubKeysShouldWork(t *testing.T) { + t.Parallel() + + noOfInitialNodes := 3 + + var listOfInitialNodes = [3]InitialNode{ + { + PubKey: pubKeys[0], + Address: address[0], + }, + { + PubKey: pubKeys[1], + Address: address[1], + }, + { + PubKey: pubKeys[2], + Address: address[2], + }, + } + + var listOfExpectedConvertedPubKeysEligibleNodes = make([]string, 2) + pubKeyConverter := mock.NewPubkeyConverterMock(96) + + for i := 0; i < 2; i++ { + convertedValue, err := pubKeyConverter.Decode(listOfInitialNodes[i].PubKey) + require.Nil(t, err) + require.NotNil(t, convertedValue) + listOfExpectedConvertedPubKeysEligibleNodes[i] = string(convertedValue) + } + + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 3, + genesisMaxShards: 1, + }) + require.Nil(t, err) + ns.Hysteresis = 0.2 + ns.Adaptivity = false + + ns = createAndAssignNodes(ns, noOfInitialNodes) + + allEligibleNodes := ns.InitialNodesPubKeys() + + require.Equal(t, listOfExpectedConvertedPubKeysEligibleNodes[0], allEligibleNodes[(core.MetachainShardId)][0]) + require.Equal(t, 
listOfExpectedConvertedPubKeysEligibleNodes[1], allEligibleNodes[0][0]) + +} + +func TestNodesSetup_InitialEligibleNodesPubKeysForShardShouldErrShardIdOutOfRange(t *testing.T) { + t.Parallel() + + noOfInitialNodes := 3 + + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 3, + genesisMaxShards: 1, + }) + require.Nil(t, err) + ns.Hysteresis = 0.2 + ns.Adaptivity = false + + ns = createAndAssignNodes(ns, noOfInitialNodes) + + returnedPubKeys, err := ns.InitialEligibleNodesPubKeysForShard(1) + require.Nil(t, returnedPubKeys) + require.Equal(t, ErrShardIdOutOfRange, err) + +} + +func TestNodesSetup_InitialEligibleNodesPubKeysForShardShouldWork(t *testing.T) { + t.Parallel() + + noOfInitialNodes := 3 + + var listOfInitialNodes = [3]InitialNode{ + { + PubKey: pubKeys[0], + Address: address[0], + }, + { + PubKey: pubKeys[1], + Address: address[1], + }, + { + PubKey: pubKeys[2], + Address: address[2], + }, + } + + var listOfExpectedPubKeysEligibleNodes = make([]string, 2) + pubKeyConverter := mock.NewPubkeyConverterMock(96) + + for i := 0; i < 2; i++ { + convertedValue, err := pubKeyConverter.Decode(listOfInitialNodes[i].PubKey) + require.Nil(t, err) + require.NotNil(t, convertedValue) + listOfExpectedPubKeysEligibleNodes[i] = string(convertedValue) + } + + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 3, + genesisMaxShards: 1, + }) + require.Nil(t, err) + ns.Hysteresis = 0.2 + ns.Adaptivity = false + + ns = createAndAssignNodes(ns, noOfInitialNodes) + + allEligibleNodes, err := ns.InitialEligibleNodesPubKeysForShard(0) + + require.Nil(t, err) + require.Equal(t, listOfExpectedPubKeysEligibleNodes[1], allEligibleNodes[0]) +} + +func TestNodesSetup_NumberOfShardsShouldWork(t *testing.T) { + t.Parallel() + + noOfInitialNodes := 3 + + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 3, + genesisMaxShards: 1, + }) + require.Nil(t, err) + require.NotNil(t, ns) + + ns.Hysteresis = 0.2 + ns.Adaptivity = false + + ns = createAndAssignNodes(ns, noOfInitialNodes) + + require.NotNil(t, ns) + + valReturned := ns.NumberOfShards() + require.Equal(t, uint32(1), valReturned) + + valReturned = ns.MinNumberOfNodesWithHysteresis() + require.Equal(t, uint32(2), valReturned) + + valReturned = ns.MinNumberOfShardNodes() + require.Equal(t, uint32(1), valReturned) + + valReturned = ns.MinNumberOfShardNodes() + require.Equal(t, uint32(1), valReturned) + + shardConsensusGroupSize := ns.GetShardConsensusGroupSize() + require.Equal(t, uint32(1), shardConsensusGroupSize) + + metaConsensusGroupSize := ns.GetMetaConsensusGroupSize() + require.Equal(t, uint32(1), metaConsensusGroupSize) + + ns.Hysteresis = 0.5 + hysteresis := ns.GetHysteresis() + require.Equal(t, float32(0.5), hysteresis) + + ns.Adaptivity = true + adaptivity := ns.GetAdaptivity() + require.True(t, adaptivity) + + ns.StartTime = 2 + startTime := ns.GetStartTime() + require.Equal(t, int64(2), startTime) + + ns.RoundDuration = 2 + roundDuration := ns.GetRoundDuration() + require.Equal(t, uint64(2), roundDuration) + +} + +func TestNodesSetup_ExportNodesConfigShouldWork(t *testing.T) { + t.Parallel() + + noOfInitialNodes := 3 + + ns, err := createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + 
metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 3, + genesisMaxShards: 1, + }) + require.Nil(t, err) + + ns.Hysteresis = 0.2 + ns.Adaptivity = false + ns.StartTime = 10 + + ns = createAndAssignNodes(ns, noOfInitialNodes) + configNodes := ns.ExportNodesConfig() + + require.Equal(t, int64(10), configNodes.StartTime) + + var expectedNodesConfigs = make([]config.InitialNodeConfig, len(configNodes.InitialNodes)) + var actualNodesConfigs = make([]config.InitialNodeConfig, len(configNodes.InitialNodes)) + + for i, nodeConfig := range configNodes.InitialNodes { + expectedNodesConfigs[i] = config.InitialNodeConfig{PubKey: pubKeys[i], Address: address[i], InitialRating: 0} + actualNodesConfigs[i] = config.InitialNodeConfig{PubKey: nodeConfig.PubKey, Address: nodeConfig.Address, InitialRating: nodeConfig.InitialRating} + + } + + for i := range configNodes.InitialNodes { + require.Equal(t, expectedNodesConfigs[i], actualNodesConfigs[i]) + } + +} + +func TestNodesSetup_IsInterfaceNil(t *testing.T) { + t.Parallel() + + ns, _ := NewNodesSetup(config.NodesConfig{}, nil, nil, nil, 0) + require.True(t, ns.IsInterfaceNil()) + + ns, _ = createTestNodesSetup(argsTestNodesSetup{ + shardConsensusSize: 1, + shardMinNodes: 1, + metaConsensusSize: 1, + metaMinNodes: 1, + numInitialNodes: 3, + genesisMaxShards: 1, + }) + require.False(t, ns.IsInterfaceNil()) } diff --git a/sharding/oneShardCoordinator_test.go b/sharding/oneShardCoordinator_test.go new file mode 100644 index 00000000000..c2c5d68edfe --- /dev/null +++ b/sharding/oneShardCoordinator_test.go @@ -0,0 +1,33 @@ +package sharding + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/stretchr/testify/require" +) + +func TestOneShardCoordinator_NumberOfShardsShouldWork(t *testing.T) { + t.Parallel() + + oneShardCoordinator := OneShardCoordinator{} + + returnedVal := oneShardCoordinator.NumberOfShards() + require.Equal(t, uint32(1), returnedVal) + + returnedVal = oneShardCoordinator.ComputeId([]byte{}) + require.Equal(t, uint32(0), returnedVal) + + returnedVal = oneShardCoordinator.SelfId() + require.Equal(t, uint32(0), returnedVal) + + isShameShard := oneShardCoordinator.SameShard(nil, nil) + require.True(t, isShameShard) + + communicationID := oneShardCoordinator.CommunicationIdentifier(0) + require.Equal(t, core.CommunicationIdentifierBetweenShards(0, 0), communicationID) + + isInterfaceNil := oneShardCoordinator.IsInterfaceNil() + require.False(t, isInterfaceNil) + +} diff --git a/state/accounts/userAccount.go b/state/accounts/userAccount.go index d626f024559..4d7d280fdcf 100644 --- a/state/accounts/userAccount.go +++ b/state/accounts/userAccount.go @@ -210,6 +210,11 @@ func (a *userAccount) AccountDataHandler() vmcommon.AccountDataHandler { return a.dataTrieInteractor } +// GetLeavesParser returns the leaves parser +func (a *userAccount) GetLeavesParser() common.TrieLeafParser { + return a.dataTrieLeafParser +} + // IsInterfaceNil returns true if there is no value under the interface func (a *userAccount) IsInterfaceNil() bool { return a == nil diff --git a/state/syncer/baseAccoutnsSyncer_test.go b/state/syncer/baseAccoutnsSyncer_test.go index da3819b05ce..e2fcf5336f0 100644 --- a/state/syncer/baseAccoutnsSyncer_test.go +++ b/state/syncer/baseAccoutnsSyncer_test.go @@ -4,15 +4,17 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/state" 
"github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storageManager" - "github.com/stretchr/testify/require" ) func getDefaultBaseAccSyncerArgs() syncer.ArgsNewBaseAccountsSyncer { @@ -22,7 +24,7 @@ func getDefaultBaseAccSyncerArgs() syncer.ArgsNewBaseAccountsSyncer { TrieStorageManager: &storageManager.StorageManagerStub{}, RequestHandler: &testscommon.RequestHandlerStub{}, Timeout: time.Second, - Cacher: testscommon.NewCacherMock(), + Cacher: cache.NewCacherMock(), UserAccountsSyncStatisticsHandler: &testscommon.SizeSyncStatisticsHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, diff --git a/state/syncer/userAccountSyncer_test.go b/state/syncer/userAccountSyncer_test.go index eefdd96778f..3ecdf5cd178 100644 --- a/state/syncer/userAccountSyncer_test.go +++ b/state/syncer/userAccountSyncer_test.go @@ -4,15 +4,17 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/dataRetriever/mock" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storageManager" "github.com/multiversx/mx-chain-go/trie" - "github.com/stretchr/testify/assert" ) // TODO add more tests @@ -24,7 +26,7 @@ func getDefaultBaseAccSyncerArgs() ArgsNewBaseAccountsSyncer { TrieStorageManager: &storageManager.StorageManagerStub{}, RequestHandler: &testscommon.RequestHandlerStub{}, Timeout: time.Second, - Cacher: testscommon.NewCacherMock(), + Cacher: cache.NewCacherMock(), UserAccountsSyncStatisticsHandler: &testscommon.SizeSyncStatisticsHandlerStub{}, AppStatusHandler: &statusHandler.AppStatusHandlerStub{}, MaxTrieLevelInMemory: 0, @@ -95,7 +97,7 @@ func TestUserAccountsSyncer_MissingDataTrieNodeFound(t *testing.T) { rootHash, _ := tr.RootHash() _ = tr.Commit() - args.Cacher = &testscommon.CacherStub{ + args.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { interceptedNode, _ := trie.NewInterceptedTrieNode(serializedLeafNode, args.Hasher) return interceptedNode, true diff --git a/state/syncer/userAccountsSyncer_test.go b/state/syncer/userAccountsSyncer_test.go index 176a4ec7497..5d7252d3b2e 100644 --- a/state/syncer/userAccountsSyncer_test.go +++ b/state/syncer/userAccountsSyncer_test.go @@ -10,6 +10,9 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/hashing" 
"github.com/multiversx/mx-chain-core-go/marshal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/api/mock" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" @@ -20,14 +23,13 @@ import ( "github.com/multiversx/mx-chain-go/state/parsers" "github.com/multiversx/mx-chain-go/state/syncer" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/storageManager" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/storageMarker" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func getDefaultUserAccountsSyncerArgs() syncer.ArgsNewUserAccountsSyncer { @@ -148,7 +150,7 @@ func TestUserAccountsSyncer_SyncAccounts(t *testing.T) { }, } - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() cacher.Put(key, itn, 0) args.Cacher = cacher @@ -228,7 +230,7 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { }, } - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() cacher.Put(key, itn, 0) args.Cacher = cacher @@ -285,7 +287,7 @@ func TestUserAccountsSyncer_SyncAccountDataTries(t *testing.T) { }, } - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() cacher.Put(key, itn, 0) args.Cacher = cacher @@ -366,7 +368,7 @@ func TestUserAccountsSyncer_MissingDataTrieNodeFound(t *testing.T) { rootHash, _ := tr.RootHash() _ = tr.Commit() - args.Cacher = &testscommon.CacherStub{ + args.Cacher = &cache.CacherStub{ GetCalled: func(key []byte) (value interface{}, ok bool) { interceptedNode, _ := trie.NewInterceptedTrieNode(serializedLeafNode, args.Hasher) return interceptedNode, true diff --git a/state/syncer/validatorAccountsSyncer_test.go b/state/syncer/validatorAccountsSyncer_test.go index b4a025883f1..1ba90712704 100644 --- a/state/syncer/validatorAccountsSyncer_test.go +++ b/state/syncer/validatorAccountsSyncer_test.go @@ -4,15 +4,16 @@ import ( "errors" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/state/syncer" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/storageManager" "github.com/multiversx/mx-chain-go/trie" "github.com/multiversx/mx-chain-go/trie/storageMarker" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestNewValidatorAccountsSyncer(t *testing.T) { @@ -93,7 +94,7 @@ func TestValidatorAccountsSyncer_SyncAccounts(t *testing.T) { }, } - cacher := testscommon.NewCacherMock() + cacher := 
cache.NewCacherMock() cacher.Put(key, itn, 0) args.Cacher = cacher diff --git a/statusHandler/statusMetricsProvider.go b/statusHandler/statusMetricsProvider.go index f602a126554..17fd6067aef 100644 --- a/statusHandler/statusMetricsProvider.go +++ b/statusHandler/statusMetricsProvider.go @@ -382,6 +382,7 @@ func (sm *statusMetrics) EnableEpochsMetrics() (map[string]interface{}, error) { enableEpochsMetrics[common.MetricFixRelayedMoveBalanceToNonPayableSCEnableEpoch] = sm.uint64Metrics[common.MetricFixRelayedMoveBalanceToNonPayableSCEnableEpoch] enableEpochsMetrics[common.MetricRelayedTransactionsV3EnableEpoch] = sm.uint64Metrics[common.MetricRelayedTransactionsV3EnableEpoch] enableEpochsMetrics[common.MetricRelayedTransactionsV3FixESDTTransferEnableEpoch] = sm.uint64Metrics[common.MetricRelayedTransactionsV3FixESDTTransferEnableEpoch] + enableEpochsMetrics[common.MetricCheckBuiltInCallOnTransferValueAndFailEnableRound] = sm.uint64Metrics[common.MetricCheckBuiltInCallOnTransferValueAndFailEnableRound] enableEpochsMetrics[common.MetricMaskVMInternalDependenciesErrorsEnableEpoch] = sm.uint64Metrics[common.MetricMaskVMInternalDependenciesErrorsEnableEpoch] enableEpochsMetrics[common.MetricFixBackTransferOPCODEEnableEpoch] = sm.uint64Metrics[common.MetricFixBackTransferOPCODEEnableEpoch] enableEpochsMetrics[common.MetricValidationOnGobDecodeEnableEpoch] = sm.uint64Metrics[common.MetricValidationOnGobDecodeEnableEpoch] diff --git a/statusHandler/statusMetricsProvider_test.go b/statusHandler/statusMetricsProvider_test.go index 89941a0f676..776e8a766d2 100644 --- a/statusHandler/statusMetricsProvider_test.go +++ b/statusHandler/statusMetricsProvider_test.go @@ -406,6 +406,7 @@ func TestStatusMetrics_EnableEpochMetrics(t *testing.T) { sm.SetUInt64Value(common.MetricFixRelayedMoveBalanceToNonPayableSCEnableEpoch, uint64(4)) sm.SetUInt64Value(common.MetricRelayedTransactionsV3EnableEpoch, uint64(4)) sm.SetUInt64Value(common.MetricRelayedTransactionsV3FixESDTTransferEnableEpoch, uint64(4)) + sm.SetUInt64Value(common.MetricCheckBuiltInCallOnTransferValueAndFailEnableRound, uint64(4)) sm.SetUInt64Value(common.MetricMaskVMInternalDependenciesErrorsEnableEpoch, uint64(4)) sm.SetUInt64Value(common.MetricFixBackTransferOPCODEEnableEpoch, uint64(4)) sm.SetUInt64Value(common.MetricValidationOnGobDecodeEnableEpoch, uint64(4)) @@ -544,6 +545,7 @@ func TestStatusMetrics_EnableEpochMetrics(t *testing.T) { common.MetricFixRelayedMoveBalanceToNonPayableSCEnableEpoch: uint64(4), common.MetricRelayedTransactionsV3EnableEpoch: uint64(4), common.MetricRelayedTransactionsV3FixESDTTransferEnableEpoch: uint64(4), + common.MetricCheckBuiltInCallOnTransferValueAndFailEnableRound: uint64(4), common.MetricMaskVMInternalDependenciesErrorsEnableEpoch: uint64(4), common.MetricFixBackTransferOPCODEEnableEpoch: uint64(4), common.MetricValidationOnGobDecodeEnableEpoch: uint64(4), diff --git a/storage/factory/storageServiceFactory.go b/storage/factory/storageServiceFactory.go index c153e6b2cc8..d83c1088d16 100644 --- a/storage/factory/storageServiceFactory.go +++ b/storage/factory/storageServiceFactory.go @@ -235,6 +235,16 @@ func (psf *StorageServiceFactory) createAndAddBaseStorageUnits( } store.AddStorer(dataRetriever.MetaBlockUnit, metaBlockUnit) + proofsUnitArgs, err := psf.createPruningStorerArgs(psf.generalConfig.ProofsStorage, disabledCustomDatabaseRemover) + if err != nil { + return err + } + proofsUnit, err := psf.createPruningPersister(proofsUnitArgs) + if err != nil { + return fmt.Errorf("%w for ProofsStorage", err) + } + 
store.AddStorer(dataRetriever.ProofsUnit, proofsUnit) + metaHdrHashNonceUnit, err := psf.createStaticStorageUnit(psf.generalConfig.MetaHdrNonceHashStorage, shardID, emptyDBPathSuffix) if err != nil { return fmt.Errorf("%w for MetaHdrNonceHashStorage", err) @@ -338,7 +348,7 @@ func (psf *StorageServiceFactory) CreateForShard() (dataRetriever.StorageService } store.AddStorer(dataRetriever.PeerChangesUnit, peerBlockUnit) - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(psf.shardCoordinator.SelfId()) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(psf.shardCoordinator.SelfId()) store.AddStorer(hdrNonceHashDataUnit, shardHdrHashNonceUnit) err = psf.setUpDbLookupExtensions(store) @@ -392,7 +402,7 @@ func (psf *StorageServiceFactory) CreateForMeta() (dataRetriever.StorageService, store.AddStorer(dataRetriever.PeerAccountsUnit, peerAccountsUnit) for i := uint32(0); i < psf.shardCoordinator.NumberOfShards(); i++ { - hdrNonceHashDataUnit := dataRetriever.ShardHdrNonceHashDataUnit + dataRetriever.UnitType(i) + hdrNonceHashDataUnit := dataRetriever.GetHdrNonceHashDataUnit(i) store.AddStorer(hdrNonceHashDataUnit, shardHdrHashNonceUnits[i]) } diff --git a/storage/factory/storageServiceFactory_test.go b/storage/factory/storageServiceFactory_test.go index e45308f48d2..9f3081337b9 100644 --- a/storage/factory/storageServiceFactory_test.go +++ b/storage/factory/storageServiceFactory_test.go @@ -47,6 +47,7 @@ func createMockArgument(t *testing.T) StorageServiceFactoryArgs { StatusMetricsStorage: createMockStorageConfig("StatusMetricsStorage"), PeerBlockBodyStorage: createMockStorageConfig("PeerBlockBodyStorage"), TrieEpochRootHashStorage: createMockStorageConfig("TrieEpochRootHashStorage"), + ProofsStorage: createMockStorageConfig("ProofsStorage"), DbLookupExtensions: config.DbLookupExtensionsConfig{ Enabled: true, DbLookupMaxActivePersisters: 10, @@ -408,7 +409,7 @@ func TestStorageServiceFactory_CreateForShard(t *testing.T) { assert.Nil(t, err) assert.False(t, check.IfNil(storageService)) allStorers := storageService.GetAllStorers() - expectedStorers := 23 + expectedStorers := 24 assert.Equal(t, expectedStorers, len(allStorers)) storer, _ := storageService.GetStorer(dataRetriever.UserAccountsUnit) @@ -430,7 +431,7 @@ func TestStorageServiceFactory_CreateForShard(t *testing.T) { assert.False(t, check.IfNil(storageService)) allStorers := storageService.GetAllStorers() numDBLookupExtensionUnits := 6 - expectedStorers := 23 - numDBLookupExtensionUnits + expectedStorers := 24 - numDBLookupExtensionUnits assert.Equal(t, expectedStorers, len(allStorers)) _ = storageService.CloseAll() }) @@ -444,7 +445,7 @@ func TestStorageServiceFactory_CreateForShard(t *testing.T) { assert.Nil(t, err) assert.False(t, check.IfNil(storageService)) allStorers := storageService.GetAllStorers() - expectedStorers := 23 // we still have a storer for trie epoch root hash + expectedStorers := 24 // we still have a storer for trie epoch root hash assert.Equal(t, expectedStorers, len(allStorers)) _ = storageService.CloseAll() }) @@ -458,7 +459,7 @@ func TestStorageServiceFactory_CreateForShard(t *testing.T) { assert.Nil(t, err) assert.False(t, check.IfNil(storageService)) allStorers := storageService.GetAllStorers() - expectedStorers := 23 + expectedStorers := 24 assert.Equal(t, expectedStorers, len(allStorers)) storer, _ := storageService.GetStorer(dataRetriever.UserAccountsUnit) @@ -527,7 +528,7 @@ func TestStorageServiceFactory_CreateForMeta(t *testing.T) { allStorers := 
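Note on the call-site change above: the manual unit arithmetic is replaced by the dataRetriever.GetHdrNonceHashDataUnit helper. Judging only from the code being removed, the helper presumably performs the same computation; the sketch below illustrates that assumption (the package name, UnitType stand-in, constant value and function name are all placeholders, not the real dataRetriever definitions).

package example

// UnitType stands in for dataRetriever.UnitType in this sketch.
type UnitType int

// shardHdrNonceHashDataUnit stands in for dataRetriever.ShardHdrNonceHashDataUnit;
// the concrete value is irrelevant to the illustration.
const shardHdrNonceHashDataUnit UnitType = 40

// getHdrNonceHashDataUnit mirrors the arithmetic previously spelled out at the
// call sites: the per-shard header-nonce-hash unit is the base unit offset by
// the shard ID.
func getHdrNonceHashDataUnit(shardID uint32) UnitType {
	return shardHdrNonceHashDataUnit + UnitType(shardID)
}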
storageService.GetAllStorers() missingStorers := 2 // PeerChangesUnit and ShardHdrNonceHashDataUnit numShardHdrStorage := 3 - expectedStorers := 23 - missingStorers + numShardHdrStorage + expectedStorers := 24 - missingStorers + numShardHdrStorage assert.Equal(t, expectedStorers, len(allStorers)) storer, _ := storageService.GetStorer(dataRetriever.UserAccountsUnit) @@ -550,7 +551,7 @@ func TestStorageServiceFactory_CreateForMeta(t *testing.T) { allStorers := storageService.GetAllStorers() missingStorers := 2 // PeerChangesUnit and ShardHdrNonceHashDataUnit numShardHdrStorage := 3 - expectedStorers := 23 - missingStorers + numShardHdrStorage + expectedStorers := 24 - missingStorers + numShardHdrStorage assert.Equal(t, expectedStorers, len(allStorers)) storer, _ := storageService.GetStorer(dataRetriever.UserAccountsUnit) diff --git a/storage/pruning/fullHistoryPruningStorer.go b/storage/pruning/fullHistoryPruningStorer.go index 97852aa3bcd..eeabefd4a7b 100644 --- a/storage/pruning/fullHistoryPruningStorer.go +++ b/storage/pruning/fullHistoryPruningStorer.go @@ -113,7 +113,7 @@ func (fhps *FullHistoryPruningStorer) PutInEpoch(key []byte, data []byte, epoch return err } - return fhps.doPutInPersister(key, data, persister) + return fhps.doPutInPersister(key, data, persister, epoch) } func (fhps *FullHistoryPruningStorer) searchInEpoch(key []byte, epoch uint32) ([]byte, error) { diff --git a/storage/pruning/pruningStorer.go b/storage/pruning/pruningStorer.go index d40680e5c87..0238f987e1d 100644 --- a/storage/pruning/pruningStorer.go +++ b/storage/pruning/pruningStorer.go @@ -323,7 +323,7 @@ func (ps *PruningStorer) Put(key, data []byte) error { persisterToUse := ps.getPersisterToUse() - return ps.doPutInPersister(key, data, persisterToUse.getPersister()) + return ps.doPutInPersister(key, data, persisterToUse.getPersister(), persisterToUse.epoch) } func (ps *PruningStorer) getPersisterToUse() *persisterData { @@ -358,13 +358,15 @@ func (ps *PruningStorer) getPersisterToUse() *persisterData { return persisterToUse } -func (ps *PruningStorer) doPutInPersister(key, data []byte, persister storage.Persister) error { +func (ps *PruningStorer) doPutInPersister(key, data []byte, persister storage.Persister, epoch uint32) error { err := persister.Put(key, data) if err != nil { ps.cacher.Remove(key) return err } + ps.stateStatsHandler.IncrWritePersister(epoch) + return nil } @@ -385,7 +387,7 @@ func (ps *PruningStorer) PutInEpoch(key, data []byte, epoch uint32) error { } defer closePersister() - return ps.doPutInPersister(key, data, persister) + return ps.doPutInPersister(key, data, persister, epoch) } func (ps *PruningStorer) createAndInitPersisterIfClosedProtected(pd *persisterData) (storage.Persister, func(), error) { @@ -434,7 +436,7 @@ func (ps *PruningStorer) createAndInitPersister(pd *persisterData) (storage.Pers func (ps *PruningStorer) Get(key []byte) ([]byte, error) { v, ok := ps.cacher.Get(key) if ok { - ps.stateStatsHandler.IncrementCache() + ps.stateStatsHandler.IncrCache() return v.([]byte), nil } @@ -445,6 +447,8 @@ func (ps *PruningStorer) Get(key []byte) ([]byte, error) { numClosedDbs := 0 for idx := 0; idx < len(ps.activePersisters); idx++ { + ps.stateStatsHandler.IncrPersister(ps.activePersisters[idx].epoch) + val, err := ps.activePersisters[idx].persister.Get(key) if err != nil { if errors.Is(err, storage.ErrDBIsClosed) { @@ -457,8 +461,6 @@ func (ps *PruningStorer) Get(key []byte) ([]byte, error) { // if found in persistence unit, add it to cache and return _ = ps.cacher.Put(key, val, 
len(val)) - ps.stateStatsHandler.IncrementPersister(ps.activePersisters[idx].epoch) - return val, nil } @@ -499,6 +501,7 @@ func (ps *PruningStorer) GetFromEpoch(key []byte, epoch uint32) ([]byte, error) // TODO: this will be used when requesting from resolvers v, ok := ps.cacher.Get(key) if ok { + ps.stateStatsHandler.IncrCache() return v.([]byte), nil } @@ -516,6 +519,8 @@ func (ps *PruningStorer) GetFromEpoch(key []byte, epoch uint32) ([]byte, error) } defer closePersister() + ps.stateStatsHandler.IncrPersister(pd.epoch) + res, err := persister.Get(key) if err == nil { return res, nil @@ -555,9 +560,11 @@ func (ps *PruningStorer) GetBulkFromEpoch(keys [][]byte, epoch uint32) ([]data.K if ok { keyValue := data.KeyValuePair{Key: key, Value: v.([]byte)} results = append(results, keyValue) + ps.stateStatsHandler.IncrCache() continue } + ps.stateStatsHandler.IncrPersister(pd.epoch) res, errGet := persisterToRead.Get(key) if errGet != nil { log.Warn("cannot get from persister", @@ -578,6 +585,7 @@ func (ps *PruningStorer) GetBulkFromEpoch(keys [][]byte, epoch uint32) ([]data.K func (ps *PruningStorer) SearchFirst(key []byte) ([]byte, error) { v, ok := ps.cacher.Get(key) if ok { + ps.stateStatsHandler.IncrCache() return v.([]byte), nil } @@ -587,6 +595,8 @@ func (ps *PruningStorer) SearchFirst(key []byte) ([]byte, error) { ps.lock.RLock() defer ps.lock.RUnlock() for _, pd := range ps.activePersisters { + ps.stateStatsHandler.IncrPersister(pd.epoch) + res, err = pd.getPersister().Get(key) if err == nil { return res, nil @@ -606,6 +616,7 @@ func (ps *PruningStorer) SearchFirst(key []byte) ([]byte, error) { func (ps *PruningStorer) Has(key []byte) error { has := ps.cacher.Has(key) if has { + ps.stateStatsHandler.IncrCache() return nil } @@ -613,6 +624,7 @@ func (ps *PruningStorer) Has(key []byte) error { defer ps.lock.RUnlock() for _, persister := range ps.activePersisters { + ps.stateStatsHandler.IncrPersister(persister.epoch) if persister.getPersister().Has(key) != nil { continue } @@ -642,6 +654,7 @@ func (ps *PruningStorer) RemoveFromCurrentEpoch(key []byte) error { persisterToUse := ps.activePersisters[0] + ps.stateStatsHandler.IncrWritePersister(persisterToUse.epoch) return persisterToUse.persister.Remove(key) } @@ -653,6 +666,7 @@ func (ps *PruningStorer) Remove(key []byte) error { ps.lock.RLock() defer ps.lock.RUnlock() for _, pd := range ps.activePersisters { + ps.stateStatsHandler.IncrWritePersister(pd.epoch) err = pd.persister.Remove(key) if err == nil { return nil diff --git a/storage/pruning/triePruningStorer.go b/storage/pruning/triePruningStorer.go index e7707088689..a52b2483e84 100644 --- a/storage/pruning/triePruningStorer.go +++ b/storage/pruning/triePruningStorer.go @@ -95,7 +95,7 @@ func (ps *triePruningStorer) PutInEpochWithoutCache(key []byte, data []byte, epo func (ps *triePruningStorer) GetFromOldEpochsWithoutAddingToCache(key []byte, maxEpochToSearchFrom uint32) ([]byte, core.OptionalUint32, error) { v, ok := ps.cacher.Get(key) if ok && !bytes.Equal([]byte(common.ActiveDBKey), key) { - ps.stateStatsHandler.IncrementSnapshotCache() + ps.stateStatsHandler.IncrSnapshotCache() return v.([]byte), core.OptionalUint32{}, nil } @@ -121,7 +121,7 @@ func (ps *triePruningStorer) GetFromOldEpochsWithoutAddingToCache(key []byte, ma HasValue: true, } - ps.stateStatsHandler.IncrementSnapshotPersister(epoch.Value) + ps.stateStatsHandler.IncrSnapshotPersister(epoch.Value) return val, epoch, nil } diff --git a/storage/pruning/triePruningStorer_test.go 
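Context for the pruning storer changes above: the epoch is now threaded into doPutInPersister, and the read counters (IncrCache / IncrPersister) are bumped before the persister lookup rather than after a successful read. A simplified sketch of that accounting order follows; the statsHandler interface, epochPersister type and getWithStats helper are reduced stand-ins for illustration, not the repository's actual state statistics handler.

package example

import "errors"

// statsHandler is a reduced, hypothetical view of the state statistics handler:
// only the counters touched by this diff are modelled.
type statsHandler interface {
	IncrCache()
	IncrPersister(epoch uint32)
	IncrWritePersister(epoch uint32)
}

type epochPersister struct {
	epoch uint32
	get   func(key []byte) ([]byte, error)
}

// getWithStats mirrors the read path: a cache hit bumps the cache counter and
// returns early; otherwise every active persister probe is counted before the
// read is attempted, regardless of whether the key is found there.
func getWithStats(key []byte, cached func([]byte) ([]byte, bool), persisters []epochPersister, stats statsHandler) ([]byte, error) {
	if val, ok := cached(key); ok {
		stats.IncrCache()
		return val, nil
	}

	for _, pd := range persisters {
		stats.IncrPersister(pd.epoch)
		if val, err := pd.get(key); err == nil {
			return val, nil
		}
	}

	return nil, errors.New("key not found")
}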
b/storage/pruning/triePruningStorer_test.go index d59a86d5187..e6a699f1329 100644 --- a/storage/pruning/triePruningStorer_test.go +++ b/storage/pruning/triePruningStorer_test.go @@ -8,7 +8,8 @@ import ( "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/storage/mock" "github.com/multiversx/mx-chain-go/storage/pruning" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -44,7 +45,7 @@ func TestTriePruningStorer_GetFromOldEpochsWithoutCacheSearchesOnlyOldEpochsAndR args := getDefaultArgs() ps, _ := pruning.NewTriePruningStorer(args) - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() ps.SetCacher(cacher) testKey1 := []byte("key1") @@ -81,7 +82,7 @@ func TestTriePruningStorer_GetFromOldEpochsWithCache(t *testing.T) { args := getDefaultArgs() ps, _ := pruning.NewTriePruningStorer(args) - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() ps.SetCacher(cacher) testKey1 := []byte("key1") @@ -185,7 +186,7 @@ func TestTriePruningStorer_GetFromOldEpochsWithoutCacheDoesNotSearchInCurrentSto args := getDefaultArgs() ps, _ := pruning.NewTriePruningStorer(args) - cacher := testscommon.NewCacherStub() + cacher := cache.NewCacherStub() cacher.PutCalled = func(_ []byte, _ interface{}, _ int) bool { require.Fail(t, "this should not be called") return false @@ -209,7 +210,7 @@ func TestTriePruningStorer_GetFromLastEpochSearchesOnlyLastEpoch(t *testing.T) { args := getDefaultArgs() ps, _ := pruning.NewTriePruningStorer(args) - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() ps.SetCacher(cacher) testKey1 := []byte("key1") @@ -258,7 +259,7 @@ func TestTriePruningStorer_GetFromCurrentEpochSearchesOnlyCurrentEpoch(t *testin args := getDefaultArgs() ps, _ := pruning.NewTriePruningStorer(args) - cacher := testscommon.NewCacherMock() + cacher := cache.NewCacherMock() ps.SetCacher(cacher) testKey1 := []byte("key1") diff --git a/storage/storageunit/storageunit_test.go b/storage/storageunit/storageunit_test.go index da4aea63b33..f92d70a48f7 100644 --- a/storage/storageunit/storageunit_test.go +++ b/storage/storageunit/storageunit_test.go @@ -5,21 +5,22 @@ import ( "testing" "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-storage-go/common" + "github.com/stretchr/testify/assert" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/storage/factory" "github.com/multiversx/mx-chain-go/storage/mock" "github.com/multiversx/mx-chain-go/storage/storageunit" - "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" "github.com/multiversx/mx-chain-go/testscommon/storage" - "github.com/multiversx/mx-chain-storage-go/common" - "github.com/stretchr/testify/assert" ) func TestNewStorageUnit(t *testing.T) { t.Parallel() - cacher := &testscommon.CacherStub{} + cacher := &cache.CacherStub{} persister := &mock.PersisterStub{} t.Run("nil cacher should error", func(t *testing.T) { diff --git 
a/testscommon/blockChainHookStub.go b/testscommon/blockChainHookStub.go index 6412881b77e..8ffb74d2729 100644 --- a/testscommon/blockChainHookStub.go +++ b/testscommon/blockChainHookStub.go @@ -325,7 +325,7 @@ func (stub *BlockChainHookStub) NumberOfShards() uint32 { // SetCurrentHeader - func (stub *BlockChainHookStub) SetCurrentHeader(hdr data.HeaderHandler) error { if stub.SetCurrentHeaderCalled != nil { - stub.SetCurrentHeaderCalled(hdr) + return stub.SetCurrentHeaderCalled(hdr) } return nil diff --git a/consensus/mock/bootstrapperStub.go b/testscommon/bootstrapperStubs/bootstrapperStub.go similarity index 98% rename from consensus/mock/bootstrapperStub.go rename to testscommon/bootstrapperStubs/bootstrapperStub.go index bd4a1b98bf2..346656e1b8e 100644 --- a/consensus/mock/bootstrapperStub.go +++ b/testscommon/bootstrapperStubs/bootstrapperStub.go @@ -1,8 +1,9 @@ -package mock +package bootstrapperStubs import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/multiversx/mx-chain-go/common" ) diff --git a/testscommon/cacherMock.go b/testscommon/cache/cacherMock.go similarity index 99% rename from testscommon/cacherMock.go rename to testscommon/cache/cacherMock.go index 0b1a9aa5edf..4b569d34375 100644 --- a/testscommon/cacherMock.go +++ b/testscommon/cache/cacherMock.go @@ -1,4 +1,4 @@ -package testscommon +package cache import ( "sync" diff --git a/testscommon/cacherStub.go b/testscommon/cache/cacherStub.go similarity index 99% rename from testscommon/cacherStub.go rename to testscommon/cache/cacherStub.go index e3e11dd811f..82e30610563 100644 --- a/testscommon/cacherStub.go +++ b/testscommon/cache/cacherStub.go @@ -1,4 +1,4 @@ -package testscommon +package cache // CacherStub - type CacherStub struct { diff --git a/testscommon/chainParameters/chainParametersHolderMock.go b/testscommon/chainParameters/chainParametersHolderMock.go new file mode 100644 index 00000000000..50721aa5716 --- /dev/null +++ b/testscommon/chainParameters/chainParametersHolderMock.go @@ -0,0 +1,42 @@ +package chainParameters + +import ( + "github.com/multiversx/mx-chain-go/config" +) + +var testChainParams = config.ChainParametersByEpochConfig{ + RoundDuration: 6000, + Hysteresis: 0, + EnableEpoch: 0, + ShardConsensusGroupSize: 1, + ShardMinNumNodes: 1, + MetachainConsensusGroupSize: 1, + MetachainMinNumNodes: 1, + Adaptivity: false, +} + +// ChainParametersHolderMock - +type ChainParametersHolderMock struct { +} + +// CurrentChainParameters - +func (c *ChainParametersHolderMock) CurrentChainParameters() config.ChainParametersByEpochConfig { + return testChainParams +} + +// AllChainParameters - +func (c *ChainParametersHolderMock) AllChainParameters() []config.ChainParametersByEpochConfig { + return []config.ChainParametersByEpochConfig{ + testChainParams, + } +} + +// ChainParametersForEpoch - +func (c *ChainParametersHolderMock) ChainParametersForEpoch(_ uint32) (config.ChainParametersByEpochConfig, error) { + return testChainParams, nil +} + +// IsInterfaceNil - +func (c *ChainParametersHolderMock) IsInterfaceNil() bool { + return c == nil +} diff --git a/testscommon/chainParameters/chainParametersHolderStub.go b/testscommon/chainParameters/chainParametersHolderStub.go new file mode 100644 index 00000000000..6d12bb3fa46 --- /dev/null +++ b/testscommon/chainParameters/chainParametersHolderStub.go @@ -0,0 +1,42 @@ +package chainParameters + +import 
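For the mock relocation above (CacherMock and CacherStub moving from testscommon to the testscommon/cache package), a minimal usage sketch of the new import path, relying only on the Put/Get behaviour already exercised by the updated tests; the package name and test name below are hypothetical.

package example

import (
	"testing"

	"github.com/multiversx/mx-chain-go/testscommon/cache"
)

// TestCacherMockUsage shows the mechanical change at call sites:
// testscommon.NewCacherMock() becomes cache.NewCacherMock().
func TestCacherMockUsage(t *testing.T) {
	c := cache.NewCacherMock()
	c.Put([]byte("key"), []byte("value"), 0)

	if _, ok := c.Get([]byte("key")); !ok {
		t.Fatal("expected the key to be present in the cacher mock")
	}
}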
"github.com/multiversx/mx-chain-go/config" + +// ChainParametersHandlerStub - +type ChainParametersHandlerStub struct { + CurrentChainParametersCalled func() config.ChainParametersByEpochConfig + AllChainParametersCalled func() []config.ChainParametersByEpochConfig + ChainParametersForEpochCalled func(epoch uint32) (config.ChainParametersByEpochConfig, error) +} + +// CurrentChainParameters - +func (stub *ChainParametersHandlerStub) CurrentChainParameters() config.ChainParametersByEpochConfig { + if stub.CurrentChainParametersCalled != nil { + return stub.CurrentChainParametersCalled() + } + + return config.ChainParametersByEpochConfig{} +} + +// AllChainParameters - +func (stub *ChainParametersHandlerStub) AllChainParameters() []config.ChainParametersByEpochConfig { + if stub.AllChainParametersCalled != nil { + return stub.AllChainParametersCalled() + } + + return nil +} + +// ChainParametersForEpoch - +func (stub *ChainParametersHandlerStub) ChainParametersForEpoch(epoch uint32) (config.ChainParametersByEpochConfig, error) { + if stub.ChainParametersForEpochCalled != nil { + return stub.ChainParametersForEpochCalled(epoch) + } + + return config.ChainParametersByEpochConfig{}, nil +} + +// IsInterfaceNil - +func (stub *ChainParametersHandlerStub) IsInterfaceNil() bool { + return stub == nil +} diff --git a/testscommon/chainSimulator/chainSimulatorMock.go b/testscommon/chainSimulator/chainSimulatorMock.go index 07db474a07e..7d726e43e2b 100644 --- a/testscommon/chainSimulator/chainSimulatorMock.go +++ b/testscommon/chainSimulator/chainSimulatorMock.go @@ -1,11 +1,34 @@ package chainSimulator -import "github.com/multiversx/mx-chain-go/node/chainSimulator/process" +import ( + "github.com/multiversx/mx-chain-go/node/chainSimulator/dtos" + "github.com/multiversx/mx-chain-go/node/chainSimulator/process" + "math/big" +) // ChainSimulatorMock - type ChainSimulatorMock struct { - GenerateBlocksCalled func(numOfBlocks int) error - GetNodeHandlerCalled func(shardID uint32) process.NodeHandler + GenerateBlocksCalled func(numOfBlocks int) error + GetNodeHandlerCalled func(shardID uint32) process.NodeHandler + GenerateAddressInShardCalled func(providedShardID uint32) dtos.WalletAddress + GenerateAndMintWalletAddressCalled func(targetShardID uint32, value *big.Int) (dtos.WalletAddress, error) +} + +// GenerateAddressInShard - +func (mock *ChainSimulatorMock) GenerateAddressInShard(providedShardID uint32) dtos.WalletAddress { + if mock.GenerateAddressInShardCalled != nil { + return mock.GenerateAddressInShardCalled(providedShardID) + } + + return dtos.WalletAddress{} +} + +// GenerateAndMintWalletAddress - +func (mock *ChainSimulatorMock) GenerateAndMintWalletAddress(targetShardID uint32, value *big.Int) (dtos.WalletAddress, error) { + if mock.GenerateAndMintWalletAddressCalled != nil { + return mock.GenerateAndMintWalletAddressCalled(targetShardID, value) + } + return dtos.WalletAddress{}, nil } // GenerateBlocks - diff --git a/testscommon/chainSimulator/nodeHandlerMock.go b/testscommon/chainSimulator/nodeHandlerMock.go index 3f306807130..158744fea29 100644 --- a/testscommon/chainSimulator/nodeHandlerMock.go +++ b/testscommon/chainSimulator/nodeHandlerMock.go @@ -1,6 +1,7 @@ package chainSimulator import ( + "github.com/multiversx/mx-chain-core-go/core" chainData "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-go/api/shared" 
"github.com/multiversx/mx-chain-go/consensus" @@ -21,9 +22,12 @@ type NodeHandlerMock struct { GetStateComponentsCalled func() factory.StateComponentsHolder GetFacadeHandlerCalled func() shared.FacadeHandler GetStatusCoreComponentsCalled func() factory.StatusCoreComponentsHolder + GetNetworkComponentsCalled func() factory.NetworkComponentsHolder SetKeyValueForAddressCalled func(addressBytes []byte, state map[string]string) error SetStateForAddressCalled func(address []byte, state *dtos.AddressState) error RemoveAccountCalled func(address []byte) error + GetBasePeersCalled func() map[uint32]core.PeerID + SetBasePeersCalled func(basePeers map[uint32]core.PeerID) CloseCalled func() error } @@ -112,6 +116,14 @@ func (mock *NodeHandlerMock) GetStatusCoreComponents() factory.StatusCoreCompone return nil } +// GetNetworkComponents - +func (mock *NodeHandlerMock) GetNetworkComponents() factory.NetworkComponentsHolder { + if mock.GetNetworkComponentsCalled != nil { + return mock.GetNetworkComponentsCalled() + } + return nil +} + // SetKeyValueForAddress - func (mock *NodeHandlerMock) SetKeyValueForAddress(addressBytes []byte, state map[string]string) error { if mock.SetKeyValueForAddressCalled != nil { @@ -137,6 +149,22 @@ func (mock *NodeHandlerMock) RemoveAccount(address []byte) error { return nil } +// GetBasePeers - +func (mock *NodeHandlerMock) GetBasePeers() map[uint32]core.PeerID { + if mock.GetBasePeersCalled != nil { + return mock.GetBasePeersCalled() + } + + return nil +} + +// SetBasePeers - +func (mock *NodeHandlerMock) SetBasePeers(basePeers map[uint32]core.PeerID) { + if mock.SetBasePeersCalled != nil { + mock.SetBasePeersCalled(basePeers) + } +} + // Close - func (mock *NodeHandlerMock) Close() error { if mock.CloseCalled != nil { diff --git a/node/mock/throttlerStub.go b/testscommon/common/throttlerStub.go similarity index 98% rename from node/mock/throttlerStub.go rename to testscommon/common/throttlerStub.go index 24ab94c45c3..f4f5e0a34d0 100644 --- a/node/mock/throttlerStub.go +++ b/testscommon/common/throttlerStub.go @@ -1,4 +1,4 @@ -package mock +package common // ThrottlerStub - type ThrottlerStub struct { diff --git a/testscommon/commonmocks/chainParametersNotifierStub.go b/testscommon/commonmocks/chainParametersNotifierStub.go new file mode 100644 index 00000000000..94971a354b5 --- /dev/null +++ b/testscommon/commonmocks/chainParametersNotifierStub.go @@ -0,0 +1,38 @@ +package commonmocks + +import ( + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" +) + +// ChainParametersNotifierStub - +type ChainParametersNotifierStub struct { + ChainParametersChangedCalled func(chainParameters config.ChainParametersByEpochConfig) + UpdateCurrentChainParametersCalled func(params config.ChainParametersByEpochConfig) + RegisterNotifyHandlerCalled func(handler common.ChainParametersSubscriptionHandler) +} + +// ChainParametersChanged - +func (c *ChainParametersNotifierStub) ChainParametersChanged(chainParameters config.ChainParametersByEpochConfig) { + if c.ChainParametersChangedCalled != nil { + c.ChainParametersChangedCalled(chainParameters) + } +} + +// UpdateCurrentChainParameters - +func (c *ChainParametersNotifierStub) UpdateCurrentChainParameters(params config.ChainParametersByEpochConfig) { + if c.UpdateCurrentChainParametersCalled != nil { + c.UpdateCurrentChainParametersCalled(params) + } +} + +// RegisterNotifyHandler - +func (c *ChainParametersNotifierStub) RegisterNotifyHandler(handler 
common.ChainParametersSubscriptionHandler) { + if c.RegisterNotifyHandlerCalled != nil { + c.RegisterNotifyHandlerCalled(handler) + } +} + +func (c *ChainParametersNotifierStub) IsInterfaceNil() bool { + return c == nil +} diff --git a/testscommon/components/components.go b/testscommon/components/components.go index 577d17b6276..19389096b72 100644 --- a/testscommon/components/components.go +++ b/testscommon/components/components.go @@ -9,6 +9,10 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/data/endProcess" "github.com/multiversx/mx-chain-core-go/data/outport" + logger "github.com/multiversx/mx-chain-logger-go" + wasmConfig "github.com/multiversx/mx-chain-vm-go/config" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" commonFactory "github.com/multiversx/mx-chain-go/common/factory" "github.com/multiversx/mx-chain-go/config" @@ -42,9 +46,6 @@ import ( statusHandlerMock "github.com/multiversx/mx-chain-go/testscommon/statusHandler" "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/trie" - logger "github.com/multiversx/mx-chain-logger-go" - wasmConfig "github.com/multiversx/mx-chain-vm-go/config" - "github.com/stretchr/testify/require" ) var log = logger.GetOrCreate("componentsMock") @@ -78,9 +79,25 @@ func GetCoreArgs() coreComp.CoreComponentsFactoryArgs { ConfigPathsHolder: config.ConfigurationPathsHolder{ GasScheduleDirectoryName: "../../cmd/node/config/gasSchedules", }, - RatingsConfig: CreateDummyRatingsConfig(), - EconomicsConfig: CreateDummyEconomicsConfig(), - NodesFilename: "../mock/testdata/nodesSetupMock.json", + RatingsConfig: CreateDummyRatingsConfig(), + EconomicsConfig: CreateDummyEconomicsConfig(), + NodesConfig: config.NodesConfig{ + StartTime: 0, + InitialNodes: []*config.InitialNodeConfig{ + { + PubKey: "227a5a5ec0c58171b7f4ee9ecc304ea7b176fb626741a25c967add76d6cd361d6995929f9b60a96237381091cefb1b061225e5bb930b40494a5ac9d7524fd67dfe478e5ccd80f17b093cff5722025761fb0217c39dbd5ae45e01eb5a3113be93", + Address: "erd1ulhw20j7jvgfgak5p05kv667k5k9f320sgef5ayxkt9784ql0zssrzyhjp", + }, + { + PubKey: "ef9522d654bc08ebf2725468f41a693aa7f3cf1cb93922cff1c8c81fba78274016010916f4a7e5b0855c430a724a2d0b3acd1fe8e61e37273a17d58faa8c0d3ef6b883a33ec648950469a1e9757b978d9ae662a019068a401cff56eea059fd08", + Address: "erd17c4fs6mz2aa2hcvva2jfxdsrdknu4220496jmswer9njznt22eds0rxlr4", + }, + { + PubKey: "e91ab494cedd4da346f47aaa1a3e792bea24fb9f6cc40d3546bc4ca36749b8bfb0164e40dbad2195a76ee0fd7fb7da075ecbf1b35a2ac20638d53ea5520644f8c16952225c48304bb202867e2d71d396bff5a5971f345bcfe32c7b6b0ca34c84", + Address: "erd10d2gufxesrp8g409tzxljlaefhs0rsgjle3l7nq38de59txxt8csj54cd3", + }, + }, + }, WorkingDirectory: "home", ChanStopNodeProcess: make(chan endProcess.ArgEndProcess), EpochConfig: config.EpochConfig{ diff --git a/testscommon/components/configs.go b/testscommon/components/configs.go index 12598af042c..f651d8eab76 100644 --- a/testscommon/components/configs.go +++ b/testscommon/components/configs.go @@ -76,9 +76,9 @@ func GetGeneralConfig() config.Config { {StartEpoch: 0, Version: "v0.3"}, }, TransferAndExecuteByUserAddresses: []string{ - "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", //shard 0 - 
"erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", //shard 1 - "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", //shard 2 + "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", // shard 0 + "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", // shard 1 + "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", // shard 2 }, }, }, @@ -87,9 +87,9 @@ func GetGeneralConfig() config.Config { {StartEpoch: 0, Version: "v0.3"}, }, TransferAndExecuteByUserAddresses: []string{ - "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", //shard 0 - "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", //shard 1 - "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", //shard 2 + "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", // shard 0 + "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", // shard 1 + "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", // shard 2 }, }, GasConfig: config.VirtualMachineGasConfig{ @@ -119,14 +119,14 @@ func GetGeneralConfig() config.Config { }, BuiltInFunctions: config.BuiltInFunctionsConfig{ AutomaticCrawlerAddresses: []string{ - "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", //shard 0 - "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", //shard 1 - "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", //shard 2 + "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", // shard 0 + "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", // shard 1 + "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", // shard 2 }, DNSV2Addresses: []string{ - "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", //shard 0 - "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", //shard 1 - "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", //shard 2 + "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", // shard 0 + "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", // shard 1 + "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", // shard 2 }, MaxNumAddressesInTransferRole: 100, }, @@ -156,6 +156,19 @@ func GetGeneralConfig() config.Config { MinTransactionVersion: 1, GenesisMaxNumberOfShards: 3, SetGuardianEpochsDelay: 20, + ChainParametersByEpoch: []config.ChainParametersByEpochConfig{ + { + EnableEpoch: 0, + RoundDuration: 4000, + ShardConsensusGroupSize: 1, + ShardMinNumNodes: 1, + MetachainConsensusGroupSize: 1, + MetachainMinNumNodes: 1, + Hysteresis: 0, + Adaptivity: false, + }, + }, + EpochChangeGracePeriodByEpoch: []config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}, }, Marshalizer: config.MarshalizerConfig{ Type: TestMarshalizer, @@ -200,6 +213,20 @@ func GetGeneralConfig() config.Config { ResourceStats: config.ResourceStatsConfig{ RefreshIntervalInSec: 1, }, + ProofsStorage: config.StorageConfig{ + Cache: config.CacheConfig{ + Capacity: 10000, + Type: "LRU", + Shards: 1, + }, + DB: config.DBConfig{ + FilePath: "ProofsStorage", + Type: "MemoryDB", + BatchDelaySeconds: 30, + MaxBatchSize: 6, + MaxOpenFiles: 10, + }, + }, } } @@ -276,21 +303,27 @@ func CreateDummyRatingsConfig() config.RatingsConfig { }, }, ShardChain: config.ShardChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 2, - ProposerValidatorImportance: 1, - ProposerDecreaseFactor: -4, - ValidatorDecreaseFactor: -4, - ConsecutiveMissedBlocksPenalty: 
ConsecutiveMissedBlocksPenalty, + RatingStepsByEpoch: []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: ConsecutiveMissedBlocksPenalty, + EnableEpoch: 0, + }, }, }, MetaChain: config.MetaChain{ - RatingSteps: config.RatingSteps{ - HoursToMaxRatingFromStartRating: 2, - ProposerValidatorImportance: 1, - ProposerDecreaseFactor: -4, - ValidatorDecreaseFactor: -4, - ConsecutiveMissedBlocksPenalty: ConsecutiveMissedBlocksPenalty, + RatingStepsByEpoch: []config.RatingSteps{ + { + HoursToMaxRatingFromStartRating: 2, + ProposerValidatorImportance: 1, + ProposerDecreaseFactor: -4, + ValidatorDecreaseFactor: -4, + ConsecutiveMissedBlocksPenalty: ConsecutiveMissedBlocksPenalty, + EnableEpoch: 0, + }, }, }, } diff --git a/testscommon/components/default.go b/testscommon/components/default.go index 2ff81c35a41..0e88eb371ef 100644 --- a/testscommon/components/default.go +++ b/testscommon/components/default.go @@ -133,7 +133,7 @@ func GetDefaultProcessComponents(shardCoordinator sharding.Coordinator) *mock.Pr BlockProcess: &testscommon.BlockProcessorStub{}, BlackListHdl: &testscommon.TimeCacheStub{}, BootSore: &mock.BootstrapStorerMock{}, - HeaderSigVerif: &mock.HeaderSigVerifierStub{}, + HeaderSigVerif: &consensus.HeaderSigVerifierMock{}, HeaderIntegrVerif: &mock.HeaderIntegrityVerifierStub{}, ValidatorStatistics: &testscommon.ValidatorStatisticsProcessorStub{}, ValidatorProvider: &stakingcommon.ValidatorsProviderStub{}, diff --git a/consensus/mock/broadcastMessangerMock.go b/testscommon/consensus/broadcastMessangerMock.go similarity index 60% rename from consensus/mock/broadcastMessangerMock.go rename to testscommon/consensus/broadcastMessangerMock.go index 2d659490725..ee0c5f5acad 100644 --- a/consensus/mock/broadcastMessangerMock.go +++ b/testscommon/consensus/broadcastMessangerMock.go @@ -1,4 +1,4 @@ -package mock +package consensus import ( "github.com/multiversx/mx-chain-core-go/data" @@ -7,14 +7,16 @@ import ( // BroadcastMessengerMock - type BroadcastMessengerMock struct { - BroadcastBlockCalled func(data.BodyHandler, data.HeaderHandler) error - BroadcastHeaderCalled func(data.HeaderHandler, []byte) error - PrepareBroadcastBlockDataValidatorCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, idx int, pkBytes []byte) error - PrepareBroadcastHeaderValidatorCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, idx int, pkBytes []byte) - BroadcastMiniBlocksCalled func(map[uint32][]byte, []byte) error - BroadcastTransactionsCalled func(map[string][][]byte, []byte) error - BroadcastConsensusMessageCalled func(*consensus.Message) error - BroadcastBlockDataLeaderCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, pkBytes []byte) error + BroadcastBlockCalled func(data.BodyHandler, data.HeaderHandler) error + BroadcastHeaderCalled func(data.HeaderHandler, []byte) error + BroadcastEquivalentProofCalled func(proof data.HeaderProofHandler, pkBytes []byte) error + PrepareBroadcastBlockDataValidatorCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, idx int, pkBytes []byte) error + PrepareBroadcastBlockDataWithEquivalentProofsCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, pkBytes []byte) + PrepareBroadcastHeaderValidatorCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, idx 
int, pkBytes []byte) + BroadcastMiniBlocksCalled func(map[uint32][]byte, []byte) error + BroadcastTransactionsCalled func(map[string][][]byte, []byte) error + BroadcastConsensusMessageCalled func(*consensus.Message) error + BroadcastBlockDataLeaderCalled func(h data.HeaderHandler, mbs map[uint32][]byte, txs map[string][][]byte, pkBytes []byte) error } // BroadcastBlock - @@ -71,6 +73,23 @@ func (bmm *BroadcastMessengerMock) PrepareBroadcastBlockDataValidator( } } +// PrepareBroadcastBlockDataWithEquivalentProofs - +func (bmm *BroadcastMessengerMock) PrepareBroadcastBlockDataWithEquivalentProofs( + header data.HeaderHandler, + miniBlocks map[uint32][]byte, + transactions map[string][][]byte, + pkBytes []byte, +) { + if bmm.PrepareBroadcastBlockDataWithEquivalentProofsCalled != nil { + bmm.PrepareBroadcastBlockDataWithEquivalentProofsCalled( + header, + miniBlocks, + transactions, + pkBytes, + ) + } +} + // PrepareBroadcastHeaderValidator - func (bmm *BroadcastMessengerMock) PrepareBroadcastHeaderValidator( header data.HeaderHandler, @@ -114,6 +133,14 @@ func (bmm *BroadcastMessengerMock) BroadcastHeader(headerhandler data.HeaderHand return nil } +// BroadcastEquivalentProof - +func (bmm *BroadcastMessengerMock) BroadcastEquivalentProof(proof data.HeaderProofHandler, pkBytes []byte) error { + if bmm.BroadcastEquivalentProofCalled != nil { + return bmm.BroadcastEquivalentProofCalled(proof, pkBytes) + } + return nil +} + // IsInterfaceNil returns true if there is no value under the interface func (bmm *BroadcastMessengerMock) IsInterfaceNil() bool { return bmm == nil diff --git a/consensus/mock/chronologyHandlerMock.go b/testscommon/consensus/chronologyHandlerMock.go similarity index 98% rename from consensus/mock/chronologyHandlerMock.go rename to testscommon/consensus/chronologyHandlerMock.go index 789387845de..0cfceca2eb9 100644 --- a/consensus/mock/chronologyHandlerMock.go +++ b/testscommon/consensus/chronologyHandlerMock.go @@ -1,4 +1,4 @@ -package mock +package consensus import ( "github.com/multiversx/mx-chain-go/consensus" diff --git a/testscommon/consensus/consensusStateMock.go b/testscommon/consensus/consensusStateMock.go new file mode 100644 index 00000000000..587abc6b5d8 --- /dev/null +++ b/testscommon/consensus/consensusStateMock.go @@ -0,0 +1,660 @@ +package consensus + +import ( + "time" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/data" + + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/p2p" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" +) + +// ConsensusStateMock - +type ConsensusStateMock struct { + ResetConsensusStateCalled func() + IsNodeLeaderInCurrentRoundCalled func(node string) bool + IsSelfLeaderInCurrentRoundCalled func() bool + GetLeaderCalled func() (string, error) + GetNextConsensusGroupCalled func(randomSource []byte, round uint64, shardId uint32, nodesCoordinator nodesCoordinator.NodesCoordinator, epoch uint32) (string, []string, error) + IsConsensusDataSetCalled func() bool + IsConsensusDataEqualCalled func(data []byte) bool + IsJobDoneCalled func(node string, currentSubroundId int) bool + IsSelfJobDoneCalled func(currentSubroundId int) bool + IsCurrentSubroundFinishedCalled func(currentSubroundId int) bool + IsNodeSelfCalled func(node string) bool + IsBlockBodyAlreadyReceivedCalled func() bool + IsHeaderAlreadyReceivedCalled func() bool + CanDoSubroundJobCalled 
func(currentSubroundId int) bool + CanProcessReceivedMessageCalled func(cnsDta *consensus.Message, currentRoundIndex int64, currentSubroundId int) bool + GenerateBitmapCalled func(subroundId int) []byte + ProcessingBlockCalled func() bool + SetProcessingBlockCalled func(processingBlock bool) + ConsensusGroupSizeCalled func() int + SetThresholdCalled func(subroundId int, threshold int) + AddReceivedHeaderCalled func(headerHandler data.HeaderHandler) + GetReceivedHeadersCalled func() []data.HeaderHandler + AddMessageWithSignatureCalled func(key string, message p2p.MessageP2P) + GetMessageWithSignatureCalled func(key string) (p2p.MessageP2P, bool) + IsSubroundFinishedCalled func(subroundID int) bool + GetDataCalled func() []byte + SetDataCalled func(data []byte) + IsMultiKeyLeaderInCurrentRoundCalled func() bool + IsLeaderJobDoneCalled func(currentSubroundId int) bool + IsMultiKeyJobDoneCalled func(currentSubroundId int) bool + GetMultikeyRedundancyStepInReasonCalled func() string + ResetRoundsWithoutReceivedMessagesCalled func(pkBytes []byte, pid core.PeerID) + GetRoundCanceledCalled func() bool + SetRoundCanceledCalled func(state bool) + GetRoundIndexCalled func() int64 + SetRoundIndexCalled func(roundIndex int64) + GetRoundTimeStampCalled func() time.Time + SetRoundTimeStampCalled func(roundTimeStamp time.Time) + GetExtendedCalledCalled func() bool + GetBodyCalled func() data.BodyHandler + SetBodyCalled func(body data.BodyHandler) + GetHeaderCalled func() data.HeaderHandler + SetHeaderCalled func(header data.HeaderHandler) + GetWaitingAllSignaturesTimeOutCalled func() bool + SetWaitingAllSignaturesTimeOutCalled func(b bool) + ConsensusGroupIndexCalled func(pubKey string) (int, error) + SelfConsensusGroupIndexCalled func() (int, error) + SetEligibleListCalled func(eligibleList map[string]struct{}) + ConsensusGroupCalled func() []string + SetConsensusGroupCalled func(consensusGroup []string) + SetLeaderCalled func(leader string) + SetConsensusGroupSizeCalled func(consensusGroupSize int) + SelfPubKeyCalled func() string + SetSelfPubKeyCalled func(selfPubKey string) + JobDoneCalled func(key string, subroundId int) (bool, error) + SetJobDoneCalled func(key string, subroundId int, value bool) error + SelfJobDoneCalled func(subroundId int) (bool, error) + IsNodeInConsensusGroupCalled func(node string) bool + IsNodeInEligibleListCalled func(node string) bool + ComputeSizeCalled func(subroundId int) int + ResetRoundStateCalled func() + IsMultiKeyInConsensusGroupCalled func() bool + IsKeyManagedBySelfCalled func(pkBytes []byte) bool + IncrementRoundsWithoutReceivedMessagesCalled func(pkBytes []byte) + GetKeysHandlerCalled func() consensus.KeysHandler + LeaderCalled func() string + StatusCalled func(subroundId int) int + SetStatusCalled func(subroundId int, subroundStatus int) + ResetRoundStatusCalled func() + ThresholdCalled func(subroundId int) int + FallbackThresholdCalled func(subroundId int) int + SetFallbackThresholdCalled func(subroundId int, threshold int) + ResetConsensusRoundStateCalled func() +} + +// AddReceivedHeader - +func (cnsm *ConsensusStateMock) AddReceivedHeader(headerHandler data.HeaderHandler) { + if cnsm.AddReceivedHeaderCalled != nil { + cnsm.AddReceivedHeaderCalled(headerHandler) + } +} + +// GetReceivedHeaders - +func (cnsm *ConsensusStateMock) GetReceivedHeaders() []data.HeaderHandler { + if cnsm.GetReceivedHeadersCalled != nil { + return cnsm.GetReceivedHeadersCalled() + } + return nil +} + +// AddMessageWithSignature - +func (cnsm *ConsensusStateMock) 
AddMessageWithSignature(key string, message p2p.MessageP2P) { + if cnsm.AddMessageWithSignatureCalled != nil { + cnsm.AddMessageWithSignatureCalled(key, message) + } +} + +// GetMessageWithSignature - +func (cnsm *ConsensusStateMock) GetMessageWithSignature(key string) (p2p.MessageP2P, bool) { + if cnsm.GetMessageWithSignatureCalled != nil { + return cnsm.GetMessageWithSignatureCalled(key) + } + return nil, false +} + +// IsSubroundFinished - +func (cnsm *ConsensusStateMock) IsSubroundFinished(subroundID int) bool { + if cnsm.IsSubroundFinishedCalled != nil { + return cnsm.IsSubroundFinishedCalled(subroundID) + } + return false +} + +// GetData - +func (cnsm *ConsensusStateMock) GetData() []byte { + if cnsm.GetDataCalled != nil { + return cnsm.GetDataCalled() + } + return nil +} + +// SetData - +func (cnsm *ConsensusStateMock) SetData(data []byte) { + if cnsm.SetDataCalled != nil { + cnsm.SetDataCalled(data) + } +} + +// IsMultiKeyLeaderInCurrentRound - +func (cnsm *ConsensusStateMock) IsMultiKeyLeaderInCurrentRound() bool { + if cnsm.IsMultiKeyLeaderInCurrentRoundCalled != nil { + return cnsm.IsMultiKeyLeaderInCurrentRoundCalled() + } + return false +} + +// IsLeaderJobDone - +func (cnsm *ConsensusStateMock) IsLeaderJobDone(currentSubroundId int) bool { + if cnsm.IsLeaderJobDoneCalled != nil { + return cnsm.IsLeaderJobDoneCalled(currentSubroundId) + } + return false +} + +// IsMultiKeyJobDone - +func (cnsm *ConsensusStateMock) IsMultiKeyJobDone(currentSubroundId int) bool { + if cnsm.IsMultiKeyJobDoneCalled != nil { + return cnsm.IsMultiKeyJobDoneCalled(currentSubroundId) + } + return false +} + +// GetMultikeyRedundancyStepInReason - +func (cnsm *ConsensusStateMock) GetMultikeyRedundancyStepInReason() string { + if cnsm.GetMultikeyRedundancyStepInReasonCalled != nil { + return cnsm.GetMultikeyRedundancyStepInReasonCalled() + } + return "" +} + +// ResetRoundsWithoutReceivedMessages - +func (cnsm *ConsensusStateMock) ResetRoundsWithoutReceivedMessages(pkBytes []byte, pid core.PeerID) { + if cnsm.ResetRoundsWithoutReceivedMessagesCalled != nil { + cnsm.ResetRoundsWithoutReceivedMessagesCalled(pkBytes, pid) + } +} + +// GetRoundCanceled - +func (cnsm *ConsensusStateMock) GetRoundCanceled() bool { + if cnsm.GetRoundCanceledCalled != nil { + return cnsm.GetRoundCanceledCalled() + } + return false +} + +// SetRoundCanceled - +func (cnsm *ConsensusStateMock) SetRoundCanceled(state bool) { + if cnsm.SetRoundCanceledCalled != nil { + cnsm.SetRoundCanceledCalled(state) + } +} + +// GetRoundIndex - +func (cnsm *ConsensusStateMock) GetRoundIndex() int64 { + if cnsm.GetRoundIndexCalled != nil { + return cnsm.GetRoundIndexCalled() + } + return 0 +} + +// SetRoundIndex - +func (cnsm *ConsensusStateMock) SetRoundIndex(roundIndex int64) { + if cnsm.SetRoundIndexCalled != nil { + cnsm.SetRoundIndexCalled(roundIndex) + } +} + +// GetRoundTimeStamp - +func (cnsm *ConsensusStateMock) GetRoundTimeStamp() time.Time { + if cnsm.GetRoundTimeStampCalled != nil { + return cnsm.GetRoundTimeStampCalled() + } + return time.Time{} +} + +// SetRoundTimeStamp - +func (cnsm *ConsensusStateMock) SetRoundTimeStamp(roundTimeStamp time.Time) { + if cnsm.SetRoundTimeStampCalled != nil { + cnsm.SetRoundTimeStampCalled(roundTimeStamp) + } +} + +// GetExtendedCalled - +func (cnsm *ConsensusStateMock) GetExtendedCalled() bool { + if cnsm.GetExtendedCalledCalled != nil { + return cnsm.GetExtendedCalledCalled() + } + return false +} + +// GetBody - +func (cnsm *ConsensusStateMock) GetBody() data.BodyHandler { + if 
cnsm.GetBodyCalled != nil { + return cnsm.GetBodyCalled() + } + return nil +} + +// SetBody - +func (cnsm *ConsensusStateMock) SetBody(body data.BodyHandler) { + if cnsm.SetBodyCalled != nil { + cnsm.SetBodyCalled(body) + } +} + +// GetHeader - +func (cnsm *ConsensusStateMock) GetHeader() data.HeaderHandler { + if cnsm.GetHeaderCalled != nil { + return cnsm.GetHeaderCalled() + } + return nil +} + +// SetHeader - +func (cnsm *ConsensusStateMock) SetHeader(header data.HeaderHandler) { + if cnsm.SetHeaderCalled != nil { + cnsm.SetHeaderCalled(header) + } +} + +// GetWaitingAllSignaturesTimeOut - +func (cnsm *ConsensusStateMock) GetWaitingAllSignaturesTimeOut() bool { + if cnsm.GetWaitingAllSignaturesTimeOutCalled != nil { + return cnsm.GetWaitingAllSignaturesTimeOutCalled() + } + return false +} + +// SetWaitingAllSignaturesTimeOut - +func (cnsm *ConsensusStateMock) SetWaitingAllSignaturesTimeOut(b bool) { + if cnsm.SetWaitingAllSignaturesTimeOutCalled != nil { + cnsm.SetWaitingAllSignaturesTimeOutCalled(b) + } +} + +// ConsensusGroupIndex - +func (cnsm *ConsensusStateMock) ConsensusGroupIndex(pubKey string) (int, error) { + if cnsm.ConsensusGroupIndexCalled != nil { + return cnsm.ConsensusGroupIndexCalled(pubKey) + } + return 0, nil +} + +// SelfConsensusGroupIndex - +func (cnsm *ConsensusStateMock) SelfConsensusGroupIndex() (int, error) { + if cnsm.SelfConsensusGroupIndexCalled != nil { + return cnsm.SelfConsensusGroupIndexCalled() + } + return 0, nil +} + +// SetEligibleList - +func (cnsm *ConsensusStateMock) SetEligibleList(eligibleList map[string]struct{}) { + if cnsm.SetEligibleListCalled != nil { + cnsm.SetEligibleListCalled(eligibleList) + } +} + +// ConsensusGroup - +func (cnsm *ConsensusStateMock) ConsensusGroup() []string { + if cnsm.ConsensusGroupCalled != nil { + return cnsm.ConsensusGroupCalled() + } + return nil +} + +// SetConsensusGroup - +func (cnsm *ConsensusStateMock) SetConsensusGroup(consensusGroup []string) { + if cnsm.SetConsensusGroupCalled != nil { + cnsm.SetConsensusGroupCalled(consensusGroup) + } +} + +// SetLeader - +func (cnsm *ConsensusStateMock) SetLeader(leader string) { + if cnsm.SetLeaderCalled != nil { + cnsm.SetLeaderCalled(leader) + } +} + +// SetConsensusGroupSize - +func (cnsm *ConsensusStateMock) SetConsensusGroupSize(consensusGroupSize int) { + if cnsm.SetConsensusGroupSizeCalled != nil { + cnsm.SetConsensusGroupSizeCalled(consensusGroupSize) + } +} + +// SelfPubKey - +func (cnsm *ConsensusStateMock) SelfPubKey() string { + if cnsm.SelfPubKeyCalled != nil { + return cnsm.SelfPubKeyCalled() + } + return "" +} + +// SetSelfPubKey - +func (cnsm *ConsensusStateMock) SetSelfPubKey(selfPubKey string) { + if cnsm.SetSelfPubKeyCalled != nil { + cnsm.SetSelfPubKeyCalled(selfPubKey) + } +} + +// JobDone - +func (cnsm *ConsensusStateMock) JobDone(key string, subroundId int) (bool, error) { + if cnsm.JobDoneCalled != nil { + return cnsm.JobDoneCalled(key, subroundId) + } + return false, nil +} + +// SetJobDone - +func (cnsm *ConsensusStateMock) SetJobDone(key string, subroundId int, value bool) error { + if cnsm.SetJobDoneCalled != nil { + return cnsm.SetJobDoneCalled(key, subroundId, value) + } + return nil +} + +// SelfJobDone - +func (cnsm *ConsensusStateMock) SelfJobDone(subroundId int) (bool, error) { + if cnsm.SelfJobDoneCalled != nil { + return cnsm.SelfJobDoneCalled(subroundId) + } + return false, nil +} + +// IsNodeInConsensusGroup - +func (cnsm *ConsensusStateMock) IsNodeInConsensusGroup(node string) bool { + if cnsm.IsNodeInConsensusGroupCalled != nil 
{ + return cnsm.IsNodeInConsensusGroupCalled(node) + } + return false +} + +// IsNodeInEligibleList - +func (cnsm *ConsensusStateMock) IsNodeInEligibleList(node string) bool { + if cnsm.IsNodeInEligibleListCalled != nil { + return cnsm.IsNodeInEligibleListCalled(node) + } + return false +} + +// ComputeSize - +func (cnsm *ConsensusStateMock) ComputeSize(subroundId int) int { + if cnsm.ComputeSizeCalled != nil { + return cnsm.ComputeSizeCalled(subroundId) + } + return 0 +} + +// ResetRoundState - +func (cnsm *ConsensusStateMock) ResetRoundState() { + if cnsm.ResetRoundStateCalled != nil { + cnsm.ResetRoundStateCalled() + } +} + +// IsMultiKeyInConsensusGroup - +func (cnsm *ConsensusStateMock) IsMultiKeyInConsensusGroup() bool { + if cnsm.IsMultiKeyInConsensusGroupCalled != nil { + return cnsm.IsMultiKeyInConsensusGroupCalled() + } + return false +} + +// IsKeyManagedBySelf - +func (cnsm *ConsensusStateMock) IsKeyManagedBySelf(pkBytes []byte) bool { + if cnsm.IsKeyManagedBySelfCalled != nil { + return cnsm.IsKeyManagedBySelfCalled(pkBytes) + } + return false +} + +// IncrementRoundsWithoutReceivedMessages - +func (cnsm *ConsensusStateMock) IncrementRoundsWithoutReceivedMessages(pkBytes []byte) { + if cnsm.IncrementRoundsWithoutReceivedMessagesCalled != nil { + cnsm.IncrementRoundsWithoutReceivedMessagesCalled(pkBytes) + } +} + +// GetKeysHandler - +func (cnsm *ConsensusStateMock) GetKeysHandler() consensus.KeysHandler { + if cnsm.GetKeysHandlerCalled != nil { + return cnsm.GetKeysHandlerCalled() + } + return nil +} + +// Leader - +func (cnsm *ConsensusStateMock) Leader() string { + if cnsm.LeaderCalled != nil { + return cnsm.LeaderCalled() + } + return "" +} + +// Status - +func (cnsm *ConsensusStateMock) Status(subroundId int) int { + if cnsm.StatusCalled != nil { + return cnsm.StatusCalled(subroundId) + } + return 0 +} + +// SetStatus - +func (cnsm *ConsensusStateMock) SetStatus(subroundId int, subroundStatus int) { + if cnsm.SetStatusCalled != nil { + cnsm.SetStatusCalled(subroundId, subroundStatus) + } +} + +// ResetRoundStatus - +func (cnsm *ConsensusStateMock) ResetRoundStatus() { + if cnsm.ResetRoundStatusCalled != nil { + cnsm.ResetRoundStatusCalled() + } +} + +// Threshold - +func (cnsm *ConsensusStateMock) Threshold(subroundId int) int { + if cnsm.ThresholdCalled != nil { + return cnsm.ThresholdCalled(subroundId) + } + return 0 +} + +// FallbackThreshold - +func (cnsm *ConsensusStateMock) FallbackThreshold(subroundId int) int { + if cnsm.FallbackThresholdCalled != nil { + return cnsm.FallbackThresholdCalled(subroundId) + } + return 0 +} + +func (cnsm *ConsensusStateMock) SetFallbackThreshold(subroundId int, threshold int) { + if cnsm.SetFallbackThresholdCalled != nil { + cnsm.SetFallbackThresholdCalled(subroundId, threshold) + } +} + +// ResetConsensusState - +func (cnsm *ConsensusStateMock) ResetConsensusState() { + if cnsm.ResetConsensusStateCalled != nil { + cnsm.ResetConsensusStateCalled() + } +} + +// ResetConsensusRoundState - +func (cnsm *ConsensusStateMock) ResetConsensusRoundState() { + if cnsm.ResetConsensusRoundStateCalled != nil { + cnsm.ResetConsensusRoundStateCalled() + } +} + +// IsNodeLeaderInCurrentRound - +func (cnsm *ConsensusStateMock) IsNodeLeaderInCurrentRound(node string) bool { + if cnsm.IsNodeLeaderInCurrentRoundCalled != nil { + return cnsm.IsNodeLeaderInCurrentRoundCalled(node) + } + return false +} + +// IsSelfLeaderInCurrentRound - +func (cnsm *ConsensusStateMock) IsSelfLeaderInCurrentRound() bool { + if cnsm.IsSelfLeaderInCurrentRoundCalled != nil { + 
return cnsm.IsSelfLeaderInCurrentRoundCalled() + } + return false +} + +// GetLeader - +func (cnsm *ConsensusStateMock) GetLeader() (string, error) { + if cnsm.GetLeaderCalled != nil { + return cnsm.GetLeaderCalled() + } + return "", nil +} + +// GetNextConsensusGroup - +func (cnsm *ConsensusStateMock) GetNextConsensusGroup( + randomSource []byte, + round uint64, + shardId uint32, + nodesCoordinator nodesCoordinator.NodesCoordinator, + epoch uint32, +) (string, []string, error) { + if cnsm.GetNextConsensusGroupCalled != nil { + return cnsm.GetNextConsensusGroupCalled(randomSource, round, shardId, nodesCoordinator, epoch) + } + return "", nil, nil +} + +// IsConsensusDataSet - +func (cnsm *ConsensusStateMock) IsConsensusDataSet() bool { + if cnsm.IsConsensusDataSetCalled != nil { + return cnsm.IsConsensusDataSetCalled() + } + return false +} + +// IsConsensusDataEqual - +func (cnsm *ConsensusStateMock) IsConsensusDataEqual(data []byte) bool { + if cnsm.IsConsensusDataEqualCalled != nil { + return cnsm.IsConsensusDataEqualCalled(data) + } + return false +} + +// IsJobDone - +func (cnsm *ConsensusStateMock) IsJobDone(node string, currentSubroundId int) bool { + if cnsm.IsJobDoneCalled != nil { + return cnsm.IsJobDoneCalled(node, currentSubroundId) + } + return false +} + +// IsSelfJobDone - +func (cnsm *ConsensusStateMock) IsSelfJobDone(currentSubroundId int) bool { + if cnsm.IsSelfJobDoneCalled != nil { + return cnsm.IsSelfJobDoneCalled(currentSubroundId) + } + return false +} + +// IsCurrentSubroundFinished - +func (cnsm *ConsensusStateMock) IsCurrentSubroundFinished(currentSubroundId int) bool { + if cnsm.IsCurrentSubroundFinishedCalled != nil { + return cnsm.IsCurrentSubroundFinishedCalled(currentSubroundId) + } + return false +} + +// IsNodeSelf - +func (cnsm *ConsensusStateMock) IsNodeSelf(node string) bool { + if cnsm.IsNodeSelfCalled != nil { + return cnsm.IsNodeSelfCalled(node) + } + return false +} + +// IsBlockBodyAlreadyReceived - +func (cnsm *ConsensusStateMock) IsBlockBodyAlreadyReceived() bool { + if cnsm.IsBlockBodyAlreadyReceivedCalled != nil { + return cnsm.IsBlockBodyAlreadyReceivedCalled() + } + return false +} + +// IsHeaderAlreadyReceived - +func (cnsm *ConsensusStateMock) IsHeaderAlreadyReceived() bool { + if cnsm.IsHeaderAlreadyReceivedCalled != nil { + return cnsm.IsHeaderAlreadyReceivedCalled() + } + return false +} + +// CanDoSubroundJob - +func (cnsm *ConsensusStateMock) CanDoSubroundJob(currentSubroundId int) bool { + if cnsm.CanDoSubroundJobCalled != nil { + return cnsm.CanDoSubroundJobCalled(currentSubroundId) + } + return false +} + +// CanProcessReceivedMessage - +func (cnsm *ConsensusStateMock) CanProcessReceivedMessage( + cnsDta *consensus.Message, + currentRoundIndex int64, + currentSubroundId int, +) bool { + return cnsm.CanProcessReceivedMessageCalled(cnsDta, currentRoundIndex, currentSubroundId) +} + +// GenerateBitmap - +func (cnsm *ConsensusStateMock) GenerateBitmap(subroundId int) []byte { + if cnsm.GenerateBitmapCalled != nil { + return cnsm.GenerateBitmapCalled(subroundId) + } + return nil +} + +// ProcessingBlock - +func (cnsm *ConsensusStateMock) ProcessingBlock() bool { + if cnsm.ProcessingBlockCalled != nil { + return cnsm.ProcessingBlockCalled() + } + return false +} + +// SetProcessingBlock - +func (cnsm *ConsensusStateMock) SetProcessingBlock(processingBlock bool) { + if cnsm.SetProcessingBlockCalled != nil { + cnsm.SetProcessingBlockCalled(processingBlock) + } +} + +// ConsensusGroupSize - +func (cnsm *ConsensusStateMock) 
ConsensusGroupSize() int { + if cnsm.ConsensusGroupSizeCalled != nil { + return cnsm.ConsensusGroupSizeCalled() + } + return 0 +} + +// SetThreshold - +func (cnsm *ConsensusStateMock) SetThreshold(subroundId int, threshold int) { + if cnsm.SetThresholdCalled != nil { + cnsm.SetThresholdCalled(subroundId, threshold) + } +} + +// IsInterfaceNil returns true if there is no value under the interface +func (cnsm *ConsensusStateMock) IsInterfaceNil() bool { + return cnsm == nil +} diff --git a/testscommon/consensus/delayedBroadcasterMock.go b/testscommon/consensus/delayedBroadcasterMock.go new file mode 100644 index 00000000000..9cab4defcc6 --- /dev/null +++ b/testscommon/consensus/delayedBroadcasterMock.go @@ -0,0 +1,74 @@ +package consensus + +import ( + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/consensus" + + "github.com/multiversx/mx-chain-go/consensus/broadcast/shared" +) + +// DelayedBroadcasterMock - +type DelayedBroadcasterMock struct { + SetLeaderDataCalled func(data *shared.DelayedBroadcastData) error + SetValidatorDataCalled func(data *shared.DelayedBroadcastData) error + SetHeaderForValidatorCalled func(vData *shared.ValidatorHeaderBroadcastData) error + SetBroadcastHandlersCalled func( + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error) error + CloseCalled func() +} + +// SetLeaderData - +func (mock *DelayedBroadcasterMock) SetLeaderData(data *shared.DelayedBroadcastData) error { + if mock.SetLeaderDataCalled != nil { + return mock.SetLeaderDataCalled(data) + } + return nil +} + +// SetValidatorData - +func (mock *DelayedBroadcasterMock) SetValidatorData(data *shared.DelayedBroadcastData) error { + if mock.SetValidatorDataCalled != nil { + return mock.SetValidatorDataCalled(data) + } + return nil +} + +// SetHeaderForValidator - +func (mock *DelayedBroadcasterMock) SetHeaderForValidator(vData *shared.ValidatorHeaderBroadcastData) error { + if mock.SetHeaderForValidatorCalled != nil { + return mock.SetHeaderForValidatorCalled(vData) + } + return nil +} + +// SetBroadcastHandlers - +func (mock *DelayedBroadcasterMock) SetBroadcastHandlers( + mbBroadcast func(mbData map[uint32][]byte, pkBytes []byte) error, + txBroadcast func(txData map[string][][]byte, pkBytes []byte) error, + headerBroadcast func(header data.HeaderHandler, pkBytes []byte) error, + consensusMessageBroadcast func(message *consensus.Message) error, +) error { + if mock.SetBroadcastHandlersCalled != nil { + return mock.SetBroadcastHandlersCalled( + mbBroadcast, + txBroadcast, + headerBroadcast, + consensusMessageBroadcast) + } + return nil +} + +// Close - +func (mock *DelayedBroadcasterMock) Close() { + if mock.CloseCalled != nil { + mock.CloseCalled() + } +} + +// IsInterfaceNil returns true if there is no value under the interface +func (mock *DelayedBroadcasterMock) IsInterfaceNil() bool { + return mock == nil +} diff --git a/consensus/mock/hasherStub.go b/testscommon/consensus/hasherStub.go similarity index 97% rename from consensus/mock/hasherStub.go rename to testscommon/consensus/hasherStub.go index f05c2fd2cc8..05bea1aaa6d 100644 --- a/consensus/mock/hasherStub.go +++ b/testscommon/consensus/hasherStub.go @@ -1,4 +1,4 @@ -package mock +package consensus // HasherStub - type HasherStub struct { diff 
--git a/testscommon/consensus/headerSigVerifierStub.go b/testscommon/consensus/headerSigVerifierStub.go new file mode 100644 index 00000000000..c50992c3c45 --- /dev/null +++ b/testscommon/consensus/headerSigVerifierStub.go @@ -0,0 +1,72 @@ +package consensus + +import "github.com/multiversx/mx-chain-core-go/data" + +// HeaderSigVerifierMock - +type HeaderSigVerifierMock struct { + VerifyRandSeedAndLeaderSignatureCalled func(header data.HeaderHandler) error + VerifySignatureCalled func(header data.HeaderHandler) error + VerifyRandSeedCalled func(header data.HeaderHandler) error + VerifyLeaderSignatureCalled func(header data.HeaderHandler) error + VerifySignatureForHashCalled func(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error + VerifyHeaderProofCalled func(proofHandler data.HeaderProofHandler) error +} + +// VerifyRandSeed - +func (mock *HeaderSigVerifierMock) VerifyRandSeed(header data.HeaderHandler) error { + if mock.VerifyRandSeedCalled != nil { + return mock.VerifyRandSeedCalled(header) + } + + return nil +} + +// VerifyRandSeedAndLeaderSignature - +func (mock *HeaderSigVerifierMock) VerifyRandSeedAndLeaderSignature(header data.HeaderHandler) error { + if mock.VerifyRandSeedAndLeaderSignatureCalled != nil { + return mock.VerifyRandSeedAndLeaderSignatureCalled(header) + } + + return nil +} + +// VerifySignature - +func (mock *HeaderSigVerifierMock) VerifySignature(header data.HeaderHandler) error { + if mock.VerifySignatureCalled != nil { + return mock.VerifySignatureCalled(header) + } + + return nil +} + +// VerifyLeaderSignature - +func (mock *HeaderSigVerifierMock) VerifyLeaderSignature(header data.HeaderHandler) error { + if mock.VerifyLeaderSignatureCalled != nil { + return mock.VerifyLeaderSignatureCalled(header) + } + + return nil +} + +// VerifySignatureForHash - +func (mock *HeaderSigVerifierMock) VerifySignatureForHash(header data.HeaderHandler, hash []byte, pubkeysBitmap []byte, signature []byte) error { + if mock.VerifySignatureForHashCalled != nil { + return mock.VerifySignatureForHashCalled(header, hash, pubkeysBitmap, signature) + } + + return nil +} + +// VerifyHeaderProof - +func (mock *HeaderSigVerifierMock) VerifyHeaderProof(proofHandler data.HeaderProofHandler) error { + if mock.VerifyHeaderProofCalled != nil { + return mock.VerifyHeaderProofCalled(proofHandler) + } + + return nil +} + +// IsInterfaceNil - +func (mock *HeaderSigVerifierMock) IsInterfaceNil() bool { + return mock == nil +} diff --git a/testscommon/consensus/initializers/initializers.go b/testscommon/consensus/initializers/initializers.go new file mode 100644 index 00000000000..187c8f02892 --- /dev/null +++ b/testscommon/consensus/initializers/initializers.go @@ -0,0 +1,156 @@ +package initializers + +import ( + crypto "github.com/multiversx/mx-chain-crypto-go" + "golang.org/x/exp/slices" + + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/spos" + "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" + "github.com/multiversx/mx-chain-go/testscommon" +) + +func createEligibleList(size int) []string { + eligibleList := make([]string, 0) + for i := 0; i < size; i++ { + eligibleList = append(eligibleList, string([]byte{byte(i + 65)})) + } + return eligibleList +} + +// CreateEligibleListFromMap creates a list of eligible nodes from a map of private keys +func CreateEligibleListFromMap(mapKeys 
map[string]crypto.PrivateKey) []string { + eligibleList := make([]string, 0, len(mapKeys)) + for key := range mapKeys { + eligibleList = append(eligibleList, key) + } + slices.Sort(eligibleList) + return eligibleList +} + +// InitConsensusStateWithNodesCoordinator creates a consensus state with a nodes coordinator +func InitConsensusStateWithNodesCoordinator(validatorsGroupSelector nodesCoordinator.NodesCoordinator) *spos.ConsensusState { + return initConsensusStateWithKeysHandlerAndNodesCoordinator(&testscommon.KeysHandlerStub{}, validatorsGroupSelector) +} + +// InitConsensusState creates a consensus state +func InitConsensusState() *spos.ConsensusState { + return InitConsensusStateWithKeysHandler(&testscommon.KeysHandlerStub{}) +} + +// InitConsensusStateWithArgs creates a consensus state with the given arguments +func InitConsensusStateWithArgs(keysHandler consensus.KeysHandler, mapKeys map[string]crypto.PrivateKey) *spos.ConsensusState { + return initConsensusStateWithKeysHandlerWithGroupSizeWithRealKeys(keysHandler, mapKeys) +} + +// InitConsensusStateWithKeysHandler creates a consensus state with a keys handler +func InitConsensusStateWithKeysHandler(keysHandler consensus.KeysHandler) *spos.ConsensusState { + consensusGroupSize := 9 + return initConsensusStateWithKeysHandlerWithGroupSize(keysHandler, consensusGroupSize) +} + +func initConsensusStateWithKeysHandlerAndNodesCoordinator(keysHandler consensus.KeysHandler, validatorsGroupSelector nodesCoordinator.NodesCoordinator) *spos.ConsensusState { + leader, consensusValidators, _ := validatorsGroupSelector.GetConsensusValidatorsPublicKeys([]byte("randomness"), 0, 0, 0) + eligibleNodesPubKeys := make(map[string]struct{}) + for _, key := range consensusValidators { + eligibleNodesPubKeys[key] = struct{}{} + } + return createConsensusStateWithNodes(eligibleNodesPubKeys, consensusValidators, leader, keysHandler) +} + +// InitConsensusStateWithArgsVerifySignature creates a consensus state with the given arguments for signature verification +func InitConsensusStateWithArgsVerifySignature(keysHandler consensus.KeysHandler, keys []string) *spos.ConsensusState { + numberOfKeys := len(keys) + eligibleNodesPubKeys := make(map[string]struct{}, numberOfKeys) + for _, key := range keys { + eligibleNodesPubKeys[key] = struct{}{} + } + + indexLeader := 1 + rcns, _ := spos.NewRoundConsensus( + eligibleNodesPubKeys, + numberOfKeys, + keys[indexLeader], + keysHandler, + ) + rcns.SetConsensusGroup(keys) + rcns.ResetRoundState() + + pBFTThreshold := numberOfKeys*2/3 + 1 + pBFTFallbackThreshold := numberOfKeys*1/2 + 1 + rthr := spos.NewRoundThreshold() + rthr.SetThreshold(1, 1) + rthr.SetThreshold(2, pBFTThreshold) + rthr.SetFallbackThreshold(1, 1) + rthr.SetFallbackThreshold(2, pBFTFallbackThreshold) + + rstatus := spos.NewRoundStatus() + rstatus.ResetRoundStatus() + cns := spos.NewConsensusState( + rcns, + rthr, + rstatus, + ) + cns.Data = []byte("X") + cns.SetRoundIndex(0) + + return cns +} + +func initConsensusStateWithKeysHandlerWithGroupSize(keysHandler consensus.KeysHandler, consensusGroupSize int) *spos.ConsensusState { + eligibleList := createEligibleList(consensusGroupSize) + + eligibleNodesPubKeys := make(map[string]struct{}) + for _, key := range eligibleList { + eligibleNodesPubKeys[key] = struct{}{} + } + + return createConsensusStateWithNodes(eligibleNodesPubKeys, eligibleList, eligibleList[0], keysHandler) +} + +func initConsensusStateWithKeysHandlerWithGroupSizeWithRealKeys(keysHandler consensus.KeysHandler, mapKeys
map[string]crypto.PrivateKey) *spos.ConsensusState { + eligibleList := CreateEligibleListFromMap(mapKeys) + + eligibleNodesPubKeys := make(map[string]struct{}, len(eligibleList)) + for _, key := range eligibleList { + eligibleNodesPubKeys[key] = struct{}{} + } + + return createConsensusStateWithNodes(eligibleNodesPubKeys, eligibleList, eligibleList[0], keysHandler) +} + +func createConsensusStateWithNodes(eligibleNodesPubKeys map[string]struct{}, consensusValidators []string, leader string, keysHandler consensus.KeysHandler) *spos.ConsensusState { + consensusGroupSize := len(consensusValidators) + rcns, _ := spos.NewRoundConsensus( + eligibleNodesPubKeys, + consensusGroupSize, + consensusValidators[1], + keysHandler, + ) + + rcns.SetConsensusGroup(consensusValidators) + rcns.SetLeader(leader) + rcns.ResetRoundState() + + pBFTThreshold := consensusGroupSize*2/3 + 1 + pBFTFallbackThreshold := consensusGroupSize*1/2 + 1 + + rthr := spos.NewRoundThreshold() + rthr.SetThreshold(1, 1) + rthr.SetThreshold(2, pBFTThreshold) + rthr.SetFallbackThreshold(1, 1) + rthr.SetFallbackThreshold(2, pBFTFallbackThreshold) + + rstatus := spos.NewRoundStatus() + rstatus.ResetRoundStatus() + + cns := spos.NewConsensusState( + rcns, + rthr, + rstatus, + ) + + cns.Data = []byte("X") + cns.SetRoundIndex(0) + return cns +} diff --git a/testscommon/consensus/invalidSignersCacheMock.go b/testscommon/consensus/invalidSignersCacheMock.go new file mode 100644 index 00000000000..f8c51387b70 --- /dev/null +++ b/testscommon/consensus/invalidSignersCacheMock.go @@ -0,0 +1,35 @@ +package consensus + +// InvalidSignersCacheMock - +type InvalidSignersCacheMock struct { + AddInvalidSignersCalled func(headerHash []byte, invalidSigners []byte, invalidPublicKeys []string) + CheckKnownInvalidSignersCalled func(headerHash []byte, invalidSigners []byte) bool + ResetCalled func() +} + +// AddInvalidSigners - +func (mock *InvalidSignersCacheMock) AddInvalidSigners(headerHash []byte, invalidSigners []byte, invalidPublicKeys []string) { + if mock.AddInvalidSignersCalled != nil { + mock.AddInvalidSignersCalled(headerHash, invalidSigners, invalidPublicKeys) + } +} + +// CheckKnownInvalidSigners - +func (mock *InvalidSignersCacheMock) CheckKnownInvalidSigners(headerHash []byte, invalidSigners []byte) bool { + if mock.CheckKnownInvalidSignersCalled != nil { + return mock.CheckKnownInvalidSignersCalled(headerHash, invalidSigners) + } + return false +} + +// Reset - +func (mock *InvalidSignersCacheMock) Reset() { + if mock.ResetCalled != nil { + mock.ResetCalled() + } +} + +// IsInterfaceNil - +func (mock *InvalidSignersCacheMock) IsInterfaceNil() bool { + return mock == nil +} diff --git a/consensus/mock/mockTestInitializer.go b/testscommon/consensus/mockTestInitializer.go similarity index 65% rename from consensus/mock/mockTestInitializer.go rename to testscommon/consensus/mockTestInitializer.go index 6fa62a5a49d..85b946c13df 100644 --- a/consensus/mock/mockTestInitializer.go +++ b/testscommon/consensus/mockTestInitializer.go @@ -1,4 +1,4 @@ -package mock +package consensus import ( "time" @@ -7,11 +7,18 @@ import ( "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-core-go/marshal" crypto "github.com/multiversx/mx-chain-crypto-go" + "github.com/multiversx/mx-chain-go/consensus" + "github.com/multiversx/mx-chain-go/consensus/mock" + "github.com/multiversx/mx-chain-go/consensus/spos" 
"github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/testscommon" - consensusMocks "github.com/multiversx/mx-chain-go/testscommon/consensus" + "github.com/multiversx/mx-chain-go/testscommon/bootstrapperStubs" "github.com/multiversx/mx-chain-go/testscommon/cryptoMocks" + "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + epochNotifierMock "github.com/multiversx/mx-chain-go/testscommon/epochNotifier" + epochstartmock "github.com/multiversx/mx-chain-go/testscommon/epochstartmock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/shardingMocks" ) @@ -119,14 +126,14 @@ func InitMultiSignerMock() *cryptoMocks.MultisignerMock { } // InitKeys - -func InitKeys() (*KeyGenMock, *PrivateKeyMock, *PublicKeyMock) { +func InitKeys() (*mock.KeyGenMock, *mock.PrivateKeyMock, *mock.PublicKeyMock) { toByteArrayMock := func() ([]byte, error) { return []byte("byteArray"), nil } - privKeyMock := &PrivateKeyMock{ + privKeyMock := &mock.PrivateKeyMock{ ToByteArrayMock: toByteArrayMock, } - pubKeyMock := &PublicKeyMock{ + pubKeyMock := &mock.PublicKeyMock{ ToByteArrayMock: toByteArrayMock, } privKeyFromByteArr := func(b []byte) (crypto.PrivateKey, error) { @@ -135,7 +142,7 @@ func InitKeys() (*KeyGenMock, *PrivateKeyMock, *PublicKeyMock) { pubKeyFromByteArr := func(b []byte) (crypto.PublicKey, error) { return pubKeyMock, nil } - keyGenMock := &KeyGenMock{ + keyGenMock := &mock.KeyGenMock{ PrivateKeyFromByteArrayMock: privKeyFromByteArr, PublicKeyFromByteArrayMock: pubKeyFromByteArr, } @@ -143,30 +150,32 @@ func InitKeys() (*KeyGenMock, *PrivateKeyMock, *PublicKeyMock) { } // InitConsensusCoreHeaderV2 - -func InitConsensusCoreHeaderV2() *ConsensusCoreMock { +func InitConsensusCoreHeaderV2() *spos.ConsensusCore { consensusCoreMock := InitConsensusCore() - consensusCoreMock.blockProcessor = InitBlockProcessorHeaderV2Mock() + consensusCoreMock.SetBlockProcessor(InitBlockProcessorHeaderV2Mock()) return consensusCoreMock } // InitConsensusCore - -func InitConsensusCore() *ConsensusCoreMock { +func InitConsensusCore() *spos.ConsensusCore { multiSignerMock := InitMultiSignerMock() return InitConsensusCoreWithMultiSigner(multiSignerMock) } // InitConsensusCoreWithMultiSigner - -func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *ConsensusCoreMock { +func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *spos.ConsensusCore { blockChain := &testscommon.ChainHandlerStub{ GetGenesisHeaderCalled: func() data.HeaderHandler { - return &block.Header{} + return &block.Header{ + RandSeed: []byte("randSeed"), + } }, } - marshalizerMock := MarshalizerMock{} + marshalizerMock := mock.MarshalizerMock{} blockProcessorMock := InitBlockProcessorMock(marshalizerMock) - bootstrapperMock := &BootstrapperStub{} + bootstrapperMock := &bootstrapperStubs.BootstrapperStub{} broadcastMessengerMock := &BroadcastMessengerMock{ BroadcastConsensusMessageCalled: func(message *consensus.Message) error { return nil @@ -176,13 +185,14 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus chronologyHandlerMock := InitChronologyHandlerMock() hasherMock := &hashingMocks.HasherMock{} roundHandlerMock := &RoundHandlerMock{} - 
shardCoordinatorMock := ShardCoordinatorMock{} + shardCoordinatorMock := mock.ShardCoordinatorMock{} syncTimerMock := &SyncTimerMock{} - validatorGroupSelector := &shardingMocks.NodesCoordinatorMock{ - ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]nodesCoordinator.Validator, error) { + nodesCoordinator := &shardingMocks.NodesCoordinatorMock{ + ComputeValidatorsGroupCalled: func(randomness []byte, round uint64, shardId uint32, epoch uint32) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { defaultSelectionChances := uint32(1) - return []nodesCoordinator.Validator{ - shardingMocks.NewValidatorMock([]byte("A"), 1, defaultSelectionChances), + leader := shardingMocks.NewValidatorMock([]byte("A"), 1, defaultSelectionChances) + return leader, []nodesCoordinator.Validator{ + leader, shardingMocks.NewValidatorMock([]byte("B"), 1, defaultSelectionChances), shardingMocks.NewValidatorMock([]byte("C"), 1, defaultSelectionChances), shardingMocks.NewValidatorMock([]byte("D"), 1, defaultSelectionChances), @@ -194,44 +204,49 @@ func InitConsensusCoreWithMultiSigner(multiSigner crypto.MultiSigner) *Consensus }, nil }, } - epochStartSubscriber := &EpochStartNotifierStub{} - antifloodHandler := &P2PAntifloodHandlerStub{} - headerPoolSubscriber := &HeadersCacherStub{} + epochStartSubscriber := &epochstartmock.EpochStartNotifierStub{} + antifloodHandler := &mock.P2PAntifloodHandlerStub{} peerHonestyHandler := &testscommon.PeerHonestyHandlerStub{} - headerSigVerifier := &HeaderSigVerifierStub{} + headerSigVerifier := &HeaderSigVerifierMock{} fallbackHeaderValidator := &testscommon.FallBackHeaderValidatorStub{} - nodeRedundancyHandler := &NodeRedundancyHandlerStub{} - scheduledProcessor := &consensusMocks.ScheduledProcessorStub{} - messageSigningHandler := &MessageSigningHandlerStub{} - peerBlacklistHandler := &PeerBlacklistHandlerStub{} + nodeRedundancyHandler := &mock.NodeRedundancyHandlerStub{} + scheduledProcessor := &ScheduledProcessorStub{} + messageSigningHandler := &mock.MessageSigningHandlerStub{} + peerBlacklistHandler := &mock.PeerBlacklistHandlerStub{} multiSignerContainer := cryptoMocks.NewMultiSignerContainerMock(multiSigner) - signingHandler := &consensusMocks.SigningHandlerStub{} - - container := &ConsensusCoreMock{ - blockChain: blockChain, - blockProcessor: blockProcessorMock, - headersSubscriber: headerPoolSubscriber, - bootstrapper: bootstrapperMock, - broadcastMessenger: broadcastMessengerMock, - chronologyHandler: chronologyHandlerMock, - hasher: hasherMock, - marshalizer: marshalizerMock, - multiSignerContainer: multiSignerContainer, - roundHandler: roundHandlerMock, - shardCoordinator: shardCoordinatorMock, - syncTimer: syncTimerMock, - validatorGroupSelector: validatorGroupSelector, - epochStartNotifier: epochStartSubscriber, - antifloodHandler: antifloodHandler, - peerHonestyHandler: peerHonestyHandler, - headerSigVerifier: headerSigVerifier, - fallbackHeaderValidator: fallbackHeaderValidator, - nodeRedundancyHandler: nodeRedundancyHandler, - scheduledProcessor: scheduledProcessor, - messageSigningHandler: messageSigningHandler, - peerBlacklistHandler: peerBlacklistHandler, - signingHandler: signingHandler, - } + signingHandler := &SigningHandlerStub{} + enableEpochsHandler := &enableEpochsHandlerMock.EnableEpochsHandlerStub{} + equivalentProofsPool := &dataRetriever.ProofsPoolMock{} + epochNotifier := &epochNotifierMock.EpochNotifierStub{} + + container, _ := spos.NewConsensusCore(&spos.ConsensusCoreArgs{ + 
BlockChain: blockChain, + BlockProcessor: blockProcessorMock, + Bootstrapper: bootstrapperMock, + BroadcastMessenger: broadcastMessengerMock, + ChronologyHandler: chronologyHandlerMock, + Hasher: hasherMock, + Marshalizer: marshalizerMock, + MultiSignerContainer: multiSignerContainer, + RoundHandler: roundHandlerMock, + ShardCoordinator: shardCoordinatorMock, + SyncTimer: syncTimerMock, + NodesCoordinator: nodesCoordinator, + EpochStartRegistrationHandler: epochStartSubscriber, + AntifloodHandler: antifloodHandler, + PeerHonestyHandler: peerHonestyHandler, + HeaderSigVerifier: headerSigVerifier, + FallbackHeaderValidator: fallbackHeaderValidator, + NodeRedundancyHandler: nodeRedundancyHandler, + ScheduledProcessor: scheduledProcessor, + MessageSigningHandler: messageSigningHandler, + PeerBlacklistHandler: peerBlacklistHandler, + SigningHandler: signingHandler, + EnableEpochsHandler: enableEpochsHandler, + EquivalentProofsPool: equivalentProofsPool, + EpochNotifier: epochNotifier, + InvalidSignersCache: &InvalidSignersCacheMock{}, + }) return container } diff --git a/consensus/mock/rounderMock.go b/testscommon/consensus/rounderMock.go similarity index 94% rename from consensus/mock/rounderMock.go rename to testscommon/consensus/rounderMock.go index 6a0625932a1..2855033a823 100644 --- a/consensus/mock/rounderMock.go +++ b/testscommon/consensus/rounderMock.go @@ -1,4 +1,4 @@ -package mock +package consensus import ( "time" @@ -24,6 +24,9 @@ func (rndm *RoundHandlerMock) BeforeGenesis() bool { return false } +// RevertOneRound - +func (rndm *RoundHandlerMock) RevertOneRound() {} + // Index - func (rndm *RoundHandlerMock) Index() int64 { if rndm.IndexCalled != nil { diff --git a/consensus/mock/sposWorkerMock.go b/testscommon/consensus/sposWorkerMock.go similarity index 55% rename from consensus/mock/sposWorkerMock.go rename to testscommon/consensus/sposWorkerMock.go index 0454370bedf..657f01ca7ca 100644 --- a/consensus/mock/sposWorkerMock.go +++ b/testscommon/consensus/sposWorkerMock.go @@ -1,10 +1,11 @@ -package mock +package consensus import ( "context" "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/p2p" ) @@ -16,8 +17,10 @@ type SposWorkerMock struct { receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool, ) AddReceivedHeaderHandlerCalled func(handler func(data.HeaderHandler)) + RemoveAllReceivedHeaderHandlersCalled func() + AddReceivedProofHandlerCalled func(handler func(proofHandler consensus.ProofHandler)) RemoveAllReceivedMessagesCallsCalled func() - ProcessReceivedMessageCalled func(message p2p.MessageP2P) error + ProcessReceivedMessageCalled func(message p2p.MessageP2P) ([]byte, error) SendConsensusMessageCalled func(cnsDta *consensus.Message) bool ExtendCalled func(subroundId int) GetConsensusStateChangedChannelsCalled func() chan bool @@ -28,12 +31,25 @@ type SposWorkerMock struct { ReceivedHeaderCalled func(headerHandler data.HeaderHandler, headerHash []byte) SetAppStatusHandlerCalled func(ash core.AppStatusHandler) error ResetConsensusMessagesCalled func() + ResetConsensusStateCalled func() + ReceivedProofCalled func(proofHandler consensus.ProofHandler) + ResetConsensusRoundStateCalled func() + ResetInvalidSignersCacheCalled func() +} + +// ResetConsensusRoundState - +func (sposWorkerMock *SposWorkerMock) ResetConsensusRoundState() { + if 
sposWorkerMock.ResetConsensusRoundStateCalled != nil { + sposWorkerMock.ResetConsensusRoundStateCalled() + } } // AddReceivedMessageCall - func (sposWorkerMock *SposWorkerMock) AddReceivedMessageCall(messageType consensus.MessageType, receivedMessageCall func(ctx context.Context, cnsDta *consensus.Message) bool) { - sposWorkerMock.AddReceivedMessageCallCalled(messageType, receivedMessageCall) + if sposWorkerMock.AddReceivedMessageCallCalled != nil { + sposWorkerMock.AddReceivedMessageCallCalled(messageType, receivedMessageCall) + } } // AddReceivedHeaderHandler - @@ -43,39 +59,71 @@ func (sposWorkerMock *SposWorkerMock) AddReceivedHeaderHandler(handler func(data } } +// RemoveAllReceivedHeaderHandlers - +func (sposWorkerMock *SposWorkerMock) RemoveAllReceivedHeaderHandlers() { + if sposWorkerMock.RemoveAllReceivedHeaderHandlersCalled != nil { + sposWorkerMock.RemoveAllReceivedHeaderHandlersCalled() + } +} + +// AddReceivedProofHandler - +func (sposWorkerMock *SposWorkerMock) AddReceivedProofHandler(handler func(proofHandler consensus.ProofHandler)) { + if sposWorkerMock.AddReceivedProofHandlerCalled != nil { + sposWorkerMock.AddReceivedProofHandlerCalled(handler) + } +} + // RemoveAllReceivedMessagesCalls - func (sposWorkerMock *SposWorkerMock) RemoveAllReceivedMessagesCalls() { - sposWorkerMock.RemoveAllReceivedMessagesCallsCalled() + if sposWorkerMock.RemoveAllReceivedMessagesCallsCalled != nil { + sposWorkerMock.RemoveAllReceivedMessagesCallsCalled() + } } // ProcessReceivedMessage - -func (sposWorkerMock *SposWorkerMock) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) error { - return sposWorkerMock.ProcessReceivedMessageCalled(message) +func (sposWorkerMock *SposWorkerMock) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) ([]byte, error) { + if sposWorkerMock.ProcessReceivedMessageCalled != nil { + return sposWorkerMock.ProcessReceivedMessageCalled(message) + } + return nil, nil } // SendConsensusMessage - func (sposWorkerMock *SposWorkerMock) SendConsensusMessage(cnsDta *consensus.Message) bool { - return sposWorkerMock.SendConsensusMessageCalled(cnsDta) + if sposWorkerMock.SendConsensusMessageCalled != nil { + return sposWorkerMock.SendConsensusMessageCalled(cnsDta) + } + return false } // Extend - func (sposWorkerMock *SposWorkerMock) Extend(subroundId int) { - sposWorkerMock.ExtendCalled(subroundId) + if sposWorkerMock.ExtendCalled != nil { + sposWorkerMock.ExtendCalled(subroundId) + } } // GetConsensusStateChangedChannel - func (sposWorkerMock *SposWorkerMock) GetConsensusStateChangedChannel() chan bool { - return sposWorkerMock.GetConsensusStateChangedChannelsCalled() + if sposWorkerMock.GetConsensusStateChangedChannelsCalled != nil { + return sposWorkerMock.GetConsensusStateChangedChannelsCalled() + } + + return nil } // BroadcastBlock - func (sposWorkerMock *SposWorkerMock) BroadcastBlock(body data.BodyHandler, header data.HeaderHandler) error { - return sposWorkerMock.GetBroadcastBlockCalled(body, header) + if sposWorkerMock.GetBroadcastBlockCalled != nil { + return sposWorkerMock.GetBroadcastBlockCalled(body, header) + } + return nil } // ExecuteStoredMessages - func (sposWorkerMock *SposWorkerMock) ExecuteStoredMessages() { - sposWorkerMock.ExecuteStoredMessagesCalled() + if sposWorkerMock.ExecuteStoredMessagesCalled != nil { + sposWorkerMock.ExecuteStoredMessagesCalled() + } } // DisplayStatistics - @@ -108,7 +156,28 @@ func (sposWorkerMock *SposWorkerMock) ResetConsensusMessages() { } } +// ResetConsensusState - +func
(sposWorkerMock *SposWorkerMock) ResetConsensusState() { + if sposWorkerMock.ResetConsensusStateCalled != nil { + sposWorkerMock.ResetConsensusStateCalled() + } +} + +// ReceivedProof - +func (sposWorkerMock *SposWorkerMock) ReceivedProof(proofHandler consensus.ProofHandler) { + if sposWorkerMock.ReceivedProofCalled != nil { + sposWorkerMock.ReceivedProofCalled(proofHandler) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (sposWorkerMock *SposWorkerMock) IsInterfaceNil() bool { return sposWorkerMock == nil } + +// ResetInvalidSignersCache - +func (sposWorkerMock *SposWorkerMock) ResetInvalidSignersCache() { + if sposWorkerMock.ResetInvalidSignersCacheCalled != nil { + sposWorkerMock.ResetInvalidSignersCacheCalled() + } +} diff --git a/consensus/mock/syncTimerMock.go b/testscommon/consensus/syncTimerMock.go similarity index 98% rename from consensus/mock/syncTimerMock.go rename to testscommon/consensus/syncTimerMock.go index 2fa41d42341..32b92bbe33b 100644 --- a/consensus/mock/syncTimerMock.go +++ b/testscommon/consensus/syncTimerMock.go @@ -1,4 +1,4 @@ -package mock +package consensus import ( "time" diff --git a/testscommon/dataRetriever/poolFactory.go b/testscommon/dataRetriever/poolFactory.go index 71a27a718f2..bff8b653001 100644 --- a/testscommon/dataRetriever/poolFactory.go +++ b/testscommon/dataRetriever/poolFactory.go @@ -6,10 +6,12 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/headersCache" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/dataRetriever/shardedData" "github.com/multiversx/mx-chain-go/dataRetriever/txpool" "github.com/multiversx/mx-chain-go/storage/cache" @@ -46,8 +48,7 @@ func CreateTxPool(numShards uint32, selfShard uint32) (dataRetriever.ShardedData ) } -// CreatePoolsHolder - -func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHolder { +func createPoolHolderArgs(numShards uint32, selfShard uint32) dataPool.DataPoolArgs { var err error txPool, err := CreateTxPool(numShards, selfShard) @@ -134,6 +135,8 @@ func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHo }) panicIfError("CreatePoolsHolder", err) + proofsPool := proofscache.NewProofsPool(3, 100) + currentBlockTransactions := dataPool.NewCurrentBlockTransactionsPool() currentEpochValidatorInfo := dataPool.NewCurrentEpochValidatorInfoPool() dataPoolArgs := dataPool.DataPoolArgs{ @@ -151,13 +154,37 @@ func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHo PeerAuthentications: peerAuthPool, Heartbeats: heartbeatPool, ValidatorsInfo: validatorsInfo, + Proofs: proofsPool, } + + return dataPoolArgs +} + +// CreatePoolsHolder - +func CreatePoolsHolder(numShards uint32, selfShard uint32) dataRetriever.PoolsHolder { + + dataPoolArgs := createPoolHolderArgs(numShards, selfShard) + holder, err := dataPool.NewDataPool(dataPoolArgs) panicIfError("CreatePoolsHolder", err) return holder } +// CreatePoolsHolderWithProofsPool - +func CreatePoolsHolderWithProofsPool( + numShards uint32, selfShard uint32, + proofsPool dataRetriever.ProofsPool, +) 
dataRetriever.PoolsHolder { + dataPoolArgs := createPoolHolderArgs(numShards, selfShard) + dataPoolArgs.Proofs = proofsPool + + holder, err := dataPool.NewDataPool(dataPoolArgs) + panicIfError("CreatePoolsHolderWithProofsPool", err) + + return holder +} + // CreatePoolsHolderWithTxPool - func CreatePoolsHolderWithTxPool(txPool dataRetriever.ShardedDataCacherNotifier) dataRetriever.PoolsHolder { var err error @@ -218,6 +245,8 @@ func CreatePoolsHolderWithTxPool(txPool dataRetriever.ShardedDataCacherNotifier) heartbeatPool, err := storageunit.NewCache(cacherConfig) panicIfError("CreatePoolsHolderWithTxPool", err) + proofsPool := proofscache.NewProofsPool(3, 100) + currentBlockTransactions := dataPool.NewCurrentBlockTransactionsPool() currentEpochValidatorInfo := dataPool.NewCurrentEpochValidatorInfoPool() dataPoolArgs := dataPool.DataPoolArgs{ @@ -235,6 +264,7 @@ func CreatePoolsHolderWithTxPool(txPool dataRetriever.ShardedDataCacherNotifier) PeerAuthentications: peerAuthPool, Heartbeats: heartbeatPool, ValidatorsInfo: validatorsInfo, + Proofs: proofsPool, } holder, err := dataPool.NewDataPool(dataPoolArgs) panicIfError("CreatePoolsHolderWithTxPool", err) diff --git a/testscommon/dataRetriever/poolsHolderMock.go b/testscommon/dataRetriever/poolsHolderMock.go index 75321b6854c..6dc5266c062 100644 --- a/testscommon/dataRetriever/poolsHolderMock.go +++ b/testscommon/dataRetriever/poolsHolderMock.go @@ -10,6 +10,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool" "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/headersCache" + proofscache "github.com/multiversx/mx-chain-go/dataRetriever/dataPool/proofsCache" "github.com/multiversx/mx-chain-go/dataRetriever/shardedData" "github.com/multiversx/mx-chain-go/dataRetriever/txpool" "github.com/multiversx/mx-chain-go/storage" @@ -34,6 +35,7 @@ type PoolsHolderMock struct { peerAuthentications storage.Cacher heartbeats storage.Cacher validatorsInfo dataRetriever.ShardedDataCacherNotifier + proofs dataRetriever.ProofsPool } // NewPoolsHolderMock - @@ -108,6 +110,8 @@ func NewPoolsHolderMock() *PoolsHolderMock { }) panicIfError("NewPoolsHolderMock", err) + holder.proofs = proofscache.NewProofsPool(3, 100) + return holder } @@ -196,6 +200,11 @@ func (holder *PoolsHolderMock) ValidatorsInfo() dataRetriever.ShardedDataCacherN return holder.validatorsInfo } +// Proofs - +func (holder *PoolsHolderMock) Proofs() dataRetriever.ProofsPool { + return holder.proofs +} + // Close - func (holder *PoolsHolderMock) Close() error { var lastError error diff --git a/testscommon/dataRetriever/poolsHolderStub.go b/testscommon/dataRetriever/poolsHolderStub.go index 106c8b96bb5..7d9051d6f10 100644 --- a/testscommon/dataRetriever/poolsHolderStub.go +++ b/testscommon/dataRetriever/poolsHolderStub.go @@ -4,6 +4,7 @@ import ( "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" ) // PoolsHolderStub - @@ -23,6 +24,7 @@ type PoolsHolderStub struct { PeerAuthenticationsCalled func() storage.Cacher HeartbeatsCalled func() storage.Cacher ValidatorsInfoCalled func() dataRetriever.ShardedDataCacherNotifier + ProofsCalled func() dataRetriever.ProofsPool CloseCalled func() error } @@ 
-73,7 +75,7 @@ func (holder *PoolsHolderStub) MiniBlocks() storage.Cacher { return holder.MiniBlocksCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // MetaBlocks - @@ -82,7 +84,7 @@ func (holder *PoolsHolderStub) MetaBlocks() storage.Cacher { return holder.MetaBlocksCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // CurrentBlockTxs - @@ -109,7 +111,7 @@ func (holder *PoolsHolderStub) TrieNodes() storage.Cacher { return holder.TrieNodesCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // TrieNodesChunks - @@ -118,7 +120,7 @@ func (holder *PoolsHolderStub) TrieNodesChunks() storage.Cacher { return holder.TrieNodesChunksCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // PeerChangesBlocks - @@ -127,7 +129,7 @@ func (holder *PoolsHolderStub) PeerChangesBlocks() storage.Cacher { return holder.PeerChangesBlocksCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // SmartContracts - @@ -136,7 +138,7 @@ func (holder *PoolsHolderStub) SmartContracts() storage.Cacher { return holder.SmartContractsCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // PeerAuthentications - @@ -145,7 +147,7 @@ func (holder *PoolsHolderStub) PeerAuthentications() storage.Cacher { return holder.PeerAuthenticationsCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // Heartbeats - @@ -154,7 +156,7 @@ func (holder *PoolsHolderStub) Heartbeats() storage.Cacher { return holder.HeartbeatsCalled() } - return testscommon.NewCacherStub() + return cache.NewCacherStub() } // ValidatorsInfo - @@ -166,6 +168,15 @@ func (holder *PoolsHolderStub) ValidatorsInfo() dataRetriever.ShardedDataCacherN return testscommon.NewShardedDataStub() } +// Proofs - +func (holder *PoolsHolderStub) Proofs() dataRetriever.ProofsPool { + if holder.ProofsCalled != nil { + return holder.ProofsCalled() + } + + return nil +} + // Close - func (holder *PoolsHolderStub) Close() error { if holder.CloseCalled != nil { diff --git a/testscommon/dataRetriever/proofsPoolMock.go b/testscommon/dataRetriever/proofsPoolMock.go new file mode 100644 index 00000000000..c7450a411fb --- /dev/null +++ b/testscommon/dataRetriever/proofsPoolMock.go @@ -0,0 +1,93 @@ +package dataRetriever + +import ( + "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-core-go/data/block" +) + +// ProofsPoolMock - +type ProofsPoolMock struct { + AddProofCalled func(headerProof data.HeaderProofHandler) bool + UpsertProofCalled func(headerProof data.HeaderProofHandler) bool + CleanupProofsBehindNonceCalled func(shardID uint32, nonce uint64) error + GetProofCalled func(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) + GetProofByNonceCalled func(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) + HasProofCalled func(shardID uint32, headerHash []byte) bool + IsProofInPoolEqualToCalled func(headerProof data.HeaderProofHandler) bool + RegisterHandlerCalled func(handler func(headerProof data.HeaderProofHandler)) +} + +// AddProof - +func (p *ProofsPoolMock) AddProof(headerProof data.HeaderProofHandler) bool { + if p.AddProofCalled != nil { + return p.AddProofCalled(headerProof) + } + + return true +} + +// UpsertProof - +func (p *ProofsPoolMock) UpsertProof(headerProof data.HeaderProofHandler) bool { + if p.UpsertProofCalled != nil { + return p.UpsertProofCalled(headerProof) + } + + return 
true +} + +// CleanupProofsBehindNonce - +func (p *ProofsPoolMock) CleanupProofsBehindNonce(shardID uint32, nonce uint64) error { + if p.CleanupProofsBehindNonceCalled != nil { + return p.CleanupProofsBehindNonceCalled(shardID, nonce) + } + + return nil +} + +// GetProof - +func (p *ProofsPoolMock) GetProof(shardID uint32, headerHash []byte) (data.HeaderProofHandler, error) { + if p.GetProofCalled != nil { + return p.GetProofCalled(shardID, headerHash) + } + + return &block.HeaderProof{}, nil +} + +// GetProofByNonce - +func (p *ProofsPoolMock) GetProofByNonce(headerNonce uint64, shardID uint32) (data.HeaderProofHandler, error) { + if p.GetProofByNonceCalled != nil { + return p.GetProofByNonceCalled(headerNonce, shardID) + } + + return &block.HeaderProof{}, nil +} + +// HasProof - +func (p *ProofsPoolMock) HasProof(shardID uint32, headerHash []byte) bool { + if p.HasProofCalled != nil { + return p.HasProofCalled(shardID, headerHash) + } + + return false +} + +// IsProofInPoolEqualTo - +func (p *ProofsPoolMock) IsProofInPoolEqualTo(headerProof data.HeaderProofHandler) bool { + if p.IsProofInPoolEqualToCalled != nil { + return p.IsProofInPoolEqualToCalled(headerProof) + } + + return false +} + +// RegisterHandler - +func (p *ProofsPoolMock) RegisterHandler(handler func(headerProof data.HeaderProofHandler)) { + if p.RegisterHandlerCalled != nil { + p.RegisterHandlerCalled(handler) + } +} + +// IsInterfaceNil - +func (p *ProofsPoolMock) IsInterfaceNil() bool { + return p == nil +} diff --git a/consensus/mock/epochStartNotifierStub.go b/testscommon/epochstartmock/epochStartNotifierStub.go similarity index 99% rename from consensus/mock/epochStartNotifierStub.go rename to testscommon/epochstartmock/epochStartNotifierStub.go index a671e0f2ead..2072ad30b5a 100644 --- a/consensus/mock/epochStartNotifierStub.go +++ b/testscommon/epochstartmock/epochStartNotifierStub.go @@ -2,6 +2,7 @@ package mock import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/testscommon/factory/coreComponentsHolderStub.go b/testscommon/factory/coreComponentsHolderStub.go index d26a12c33e2..1739c6efc4c 100644 --- a/testscommon/factory/coreComponentsHolderStub.go +++ b/testscommon/factory/coreComponentsHolderStub.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data/typeConverters" "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/consensus" "github.com/multiversx/mx-chain-go/factory" @@ -55,6 +56,10 @@ type CoreComponentsHolderStub struct { HardforkTriggerPubKeyCalled func() []byte EnableEpochsHandlerCalled func() common.EnableEpochsHandler RoundNotifierCalled func() process.RoundNotifier + ChainParametersSubscriberCalled func() process.ChainParametersSubscriber + ChainParametersHandlerCalled func() process.ChainParametersHandler + FieldsSizeCheckerCalled func() common.FieldsSizeChecker + EpochChangeGracePeriodHandlerCalled func() common.EpochChangeGracePeriodHandler } // NewCoreComponentsHolderStubFromRealComponent - @@ -95,6 +100,10 @@ func NewCoreComponentsHolderStubFromRealComponent(coreComponents factory.CoreCom HardforkTriggerPubKeyCalled: coreComponents.HardforkTriggerPubKey, EnableEpochsHandlerCalled: coreComponents.EnableEpochsHandler, RoundNotifierCalled: 
coreComponents.RoundNotifier, + ChainParametersHandlerCalled: coreComponents.ChainParametersHandler, + ChainParametersSubscriberCalled: coreComponents.ChainParametersSubscriber, + FieldsSizeCheckerCalled: coreComponents.FieldsSizeChecker, + EpochChangeGracePeriodHandlerCalled: coreComponents.EpochChangeGracePeriodHandler, } } @@ -378,6 +387,38 @@ func (stub *CoreComponentsHolderStub) RoundNotifier() process.RoundNotifier { return nil } +// ChainParametersSubscriber - +func (stub *CoreComponentsHolderStub) ChainParametersSubscriber() process.ChainParametersSubscriber { + if stub.ChainParametersSubscriberCalled != nil { + return stub.ChainParametersSubscriberCalled() + } + return nil +} + +// ChainParametersHandler - +func (stub *CoreComponentsHolderStub) ChainParametersHandler() process.ChainParametersHandler { + if stub.ChainParametersHandlerCalled != nil { + return stub.ChainParametersHandlerCalled() + } + return nil +} + +// FieldsSizeChecker - +func (stub *CoreComponentsHolderStub) FieldsSizeChecker() common.FieldsSizeChecker { + if stub.FieldsSizeCheckerCalled != nil { + return stub.FieldsSizeCheckerCalled() + } + return nil +} + +// EpochChangeGracePeriodHandler - +func (stub *CoreComponentsHolderStub) EpochChangeGracePeriodHandler() common.EpochChangeGracePeriodHandler { + if stub.EpochChangeGracePeriodHandlerCalled != nil { + return stub.EpochChangeGracePeriodHandlerCalled() + } + return nil +} + // IsInterfaceNil - func (stub *CoreComponentsHolderStub) IsInterfaceNil() bool { return stub == nil diff --git a/testscommon/factory/stateComponentsMock.go b/testscommon/factory/stateComponentsMock.go index 5aa541dffa0..0adb3f3bc10 100644 --- a/testscommon/factory/stateComponentsMock.go +++ b/testscommon/factory/stateComponentsMock.go @@ -16,6 +16,7 @@ type StateComponentsMock struct { Tries common.TriesHolder StorageManagers map[string]common.StorageManager MissingNodesNotifier common.MissingTrieNodesNotifier + LeavesRetriever common.TrieLeavesRetriever } // NewStateComponentsMockFromRealComponent - @@ -28,6 +29,7 @@ func NewStateComponentsMockFromRealComponent(stateComponents factory.StateCompon Tries: stateComponents.TriesContainer(), StorageManagers: stateComponents.TrieStorageManagers(), MissingNodesNotifier: stateComponents.MissingTrieNodesNotifier(), + LeavesRetriever: stateComponents.TrieLeavesRetriever(), } } @@ -89,6 +91,11 @@ func (scm *StateComponentsMock) MissingTrieNodesNotifier() common.MissingTrieNod return scm.MissingNodesNotifier } +// TrieLeavesRetriever - +func (scm *StateComponentsMock) TrieLeavesRetriever() common.TrieLeavesRetriever { + return scm.LeavesRetriever +} + // IsInterfaceNil - func (scm *StateComponentsMock) IsInterfaceNil() bool { return scm == nil diff --git a/testscommon/fallbackHeaderValidatorStub.go b/testscommon/fallbackHeaderValidatorStub.go index b769aa94976..2ba582c7118 100644 --- a/testscommon/fallbackHeaderValidatorStub.go +++ b/testscommon/fallbackHeaderValidatorStub.go @@ -6,7 +6,16 @@ import ( // FallBackHeaderValidatorStub - type FallBackHeaderValidatorStub struct { - ShouldApplyFallbackValidationCalled func(headerHandler data.HeaderHandler) bool + ShouldApplyFallbackValidationCalled func(headerHandler data.HeaderHandler) bool + ShouldApplyFallbackValidationForHeaderWithCalled func(shardID uint32, startOfEpochBlock bool, round uint64, prevHeaderHash []byte) bool +} + +// ShouldApplyFallbackValidationForHeaderWith - +func (fhvs *FallBackHeaderValidatorStub) ShouldApplyFallbackValidationForHeaderWith(shardID uint32, startOfEpochBlock bool, 
round uint64, prevHeaderHash []byte) bool { + if fhvs.ShouldApplyFallbackValidationForHeaderWithCalled != nil { + return fhvs.ShouldApplyFallbackValidationForHeaderWithCalled(shardID, startOfEpochBlock, round, prevHeaderHash) + } + return false } // ShouldApplyFallbackValidation - diff --git a/testscommon/fieldsSizeCheckerMock.go b/testscommon/fieldsSizeCheckerMock.go new file mode 100644 index 00000000000..f87f778c5c0 --- /dev/null +++ b/testscommon/fieldsSizeCheckerMock.go @@ -0,0 +1,22 @@ +package testscommon + +import "github.com/multiversx/mx-chain-core-go/data" + +// FieldsSizeCheckerMock - +type FieldsSizeCheckerMock struct { + IsProofSizeValidCalled func(proof data.HeaderProofHandler) bool +} + +// IsProofSizeValid - +func (mock *FieldsSizeCheckerMock) IsProofSizeValid(proof data.HeaderProofHandler) bool { + if mock.IsProofSizeValidCalled != nil { + return mock.IsProofSizeValidCalled(proof) + } + + return true +} + +// IsInterfaceNil - +func (mock *FieldsSizeCheckerMock) IsInterfaceNil() bool { + return mock == nil +} diff --git a/testscommon/generalConfig.go b/testscommon/generalConfig.go index 1eea96a2bdb..3448122e630 100644 --- a/testscommon/generalConfig.go +++ b/testscommon/generalConfig.go @@ -57,6 +57,19 @@ func GetGeneralConfig() config.Config { SyncProcessTimeInMillis: 6000, SetGuardianEpochsDelay: 20, StatusPollingIntervalSec: 10, + ChainParametersByEpoch: []config.ChainParametersByEpochConfig{ + { + EnableEpoch: 0, + RoundDuration: 6000, + ShardConsensusGroupSize: 1, + ShardMinNumNodes: 1, + MetachainConsensusGroupSize: 1, + MetachainMinNumNodes: 1, + Hysteresis: 0, + Adaptivity: false, + }, + }, + EpochChangeGracePeriodByEpoch: []config.EpochChangeGracePeriodByEpoch{{EnableEpoch: 0, GracePeriodInRounds: 1}}, }, EpochStartConfig: config.EpochStartConfig{ MinRoundsBetweenEpochs: 5, @@ -198,6 +211,16 @@ func GetGeneralConfig() config.Config { MaxOpenFiles: 10, }, }, + ProofsStorage: config.StorageConfig{ + Cache: getLRUCacheConfig(), + DB: config.DBConfig{ + FilePath: AddTimestampSuffix("Proofs"), + Type: string(storageunit.MemoryDB), + BatchDelaySeconds: 30, + MaxBatchSize: 6, + MaxOpenFiles: 10, + }, + }, MetaHdrNonceHashStorage: config.StorageConfig{ Cache: getLRUCacheConfig(), DB: config.DBConfig{ @@ -381,9 +404,9 @@ func GetGeneralConfig() config.Config { {StartEpoch: 0, Version: "*"}, }, TransferAndExecuteByUserAddresses: []string{ - "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", //shard 0 - "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", //shard 1 - "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", //shard 2 + "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", // shard 0 + "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", // shard 1 + "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", // shard 2 }, }, Querying: config.QueryVirtualMachineConfig{ @@ -393,9 +416,9 @@ func GetGeneralConfig() config.Config { {StartEpoch: 0, Version: "*"}, }, TransferAndExecuteByUserAddresses: []string{ - "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", //shard 0 - "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", //shard 1 - "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", //shard 2 + "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", // shard 0 + "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", // shard 1 + "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", // shard 2 }, }, }, 
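Reviewer note: the stubs and mocks added throughout this change follow one pattern — an exported *Called function field that, when set, overrides a permissive default. A minimal test sketch using the new FieldsSizeCheckerMock shown above illustrates it (illustrative only, not part of the diff; it assumes compilation inside the mx-chain-go module, and the test name is hypothetical):

package testscommon_test

import (
	"testing"

	"github.com/multiversx/mx-chain-core-go/data"

	"github.com/multiversx/mx-chain-go/testscommon"
)

// Hypothetical test, shown only to illustrate the *Called override pattern.
func TestFieldsSizeCheckerMock_Override(t *testing.T) {
	// Zero-value mock: IsProofSizeValid falls back to the permissive default (true).
	defaultChecker := &testscommon.FieldsSizeCheckerMock{}
	if !defaultChecker.IsProofSizeValid(nil) {
		t.Fatal("expected the default mock to consider any proof size valid")
	}

	// Overridden mock: forces the invalid-proof branch in the code under test.
	strictChecker := &testscommon.FieldsSizeCheckerMock{
		IsProofSizeValidCalled: func(proof data.HeaderProofHandler) bool {
			return false
		},
	}
	if strictChecker.IsProofSizeValid(nil) {
		t.Fatal("expected the overridden mock to reject the proof")
	}
}

The permissive zero value keeps existing tests passing unchanged, while the per-field override lets individual tests exercise failure paths without defining a full implementation.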
@@ -415,20 +438,24 @@ func GetGeneralConfig() config.Config { }, BuiltInFunctions: config.BuiltInFunctionsConfig{ AutomaticCrawlerAddresses: []string{ - "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", //shard 0 - "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", //shard 1 - "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", //shard 2 + "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", // shard 0 + "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", // shard 1 + "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", // shard 2 }, MaxNumAddressesInTransferRole: 100, DNSV2Addresses: []string{ - "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", //shard 0 - "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", //shard 1 - "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", //shard 2 + "erd1he8wwxn4az3j82p7wwqsdk794dm7hcrwny6f8dfegkfla34udx7qrf7xje", // shard 0 + "erd1fpkcgel4gcmh8zqqdt043yfcn5tyx8373kg6q2qmkxzu4dqamc0swts65c", // shard 1 + "erd1najnxxweyw6plhg8efql330nttrj6l5cf87wqsuym85s9ha0hmdqnqgenp", // shard 2 }, }, ResourceStats: config.ResourceStatsConfig{ RefreshIntervalInSec: 1, }, + InterceptedDataVerifier: config.InterceptedDataVerifierConfig{ + CacheSpanInSec: 1, + CacheExpiryInSec: 1, + }, } } diff --git a/testscommon/genericMocks/chainStorerMock.go b/testscommon/genericMocks/chainStorerMock.go index d8453ea2aa2..f06560d6075 100644 --- a/testscommon/genericMocks/chainStorerMock.go +++ b/testscommon/genericMocks/chainStorerMock.go @@ -18,6 +18,7 @@ type ChainStorerMock struct { ShardHdrNonce *StorerMock Receipts *StorerMock ScheduledSCRs *StorerMock + Proofs *StorerMock Others *StorerMock } @@ -35,6 +36,7 @@ func NewChainStorerMock(epoch uint32) *ChainStorerMock { ShardHdrNonce: NewStorerMockWithEpoch(epoch), Receipts: NewStorerMockWithEpoch(epoch), ScheduledSCRs: NewStorerMockWithEpoch(epoch), + Proofs: NewStorerMockWithErrKeyNotFound(epoch), Others: NewStorerMockWithEpoch(epoch), } } @@ -74,6 +76,8 @@ func (sm *ChainStorerMock) GetStorer(unitType dataRetriever.UnitType) (storage.S return sm.Receipts, nil case dataRetriever.ScheduledSCRsUnit: return sm.ScheduledSCRs, nil + case dataRetriever.ProofsUnit: + return sm.Proofs, nil } // According to: dataRetriever/interface.go @@ -147,6 +151,7 @@ func (sm *ChainStorerMock) GetAllStorers() map[dataRetriever.UnitType]storage.St dataRetriever.ShardHdrNonceHashDataUnit: sm.ShardHdrNonce, dataRetriever.ReceiptsUnit: sm.Receipts, dataRetriever.ScheduledSCRsUnit: sm.ScheduledSCRs, + dataRetriever.ProofsUnit: sm.Proofs, } } diff --git a/testscommon/genesisMocks/nodesSetupStub.go b/testscommon/genesisMocks/nodesSetupStub.go index ebe1cfe778a..e06a881dbf2 100644 --- a/testscommon/genesisMocks/nodesSetupStub.go +++ b/testscommon/genesisMocks/nodesSetupStub.go @@ -1,6 +1,7 @@ package genesisMocks import ( + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" ) @@ -27,6 +28,7 @@ type NodesSetupStub struct { MinMetaHysteresisNodesCalled func() uint32 GetChainIdCalled func() string GetMinTransactionVersionCalled func() uint32 + ExportNodesConfigCalled func() config.NodesConfig } // InitialNodesPubKeys - @@ -203,6 +205,15 @@ func (n *NodesSetupStub) MinMetaHysteresisNodes() uint32 { return 1 } +// ExportNodesConfig - +func (n *NodesSetupStub) ExportNodesConfig() config.NodesConfig { + if n.ExportNodesConfigCalled != nil { + return 
n.ExportNodesConfigCalled() + } + + return config.NodesConfig{} +} + // IsInterfaceNil - func (n *NodesSetupStub) IsInterfaceNil() bool { return n == nil diff --git a/testscommon/headerHandlerStub.go b/testscommon/headerHandlerStub.go index ab1d354ec60..3ce4bfd22e1 100644 --- a/testscommon/headerHandlerStub.go +++ b/testscommon/headerHandlerStub.go @@ -38,6 +38,8 @@ type HeaderHandlerStub struct { SetRandSeedCalled func(seed []byte) error SetSignatureCalled func(signature []byte) error SetLeaderSignatureCalled func(signature []byte) error + GetShardIDCalled func() uint32 + SetRootHashCalled func(hash []byte) error } // GetAccumulatedFees - @@ -89,6 +91,9 @@ func (hhs *HeaderHandlerStub) ShallowClone() data.HeaderHandler { // GetShardID - func (hhs *HeaderHandlerStub) GetShardID() uint32 { + if hhs.GetShardIDCalled != nil { + return hhs.GetShardIDCalled() + } return 1 } @@ -199,8 +204,11 @@ func (hhs *HeaderHandlerStub) SetTimeStamp(timestamp uint64) error { } // SetRootHash - -func (hhs *HeaderHandlerStub) SetRootHash(_ []byte) error { - panic("implement me") +func (hhs *HeaderHandlerStub) SetRootHash(hash []byte) error { + if hhs.SetRootHashCalled != nil { + return hhs.SetRootHashCalled(hash) + } + return nil } // SetPrevHash - diff --git a/testscommon/interceptorContainerStub.go b/testscommon/interceptorContainerStub.go index ea07dbd8857..6b183fd6cc5 100644 --- a/testscommon/interceptorContainerStub.go +++ b/testscommon/interceptorContainerStub.go @@ -31,8 +31,8 @@ func (ics *InterceptorsContainerStub) Get(topic string) (process.Interceptor, er } return &InterceptorStub{ - ProcessReceivedMessageCalled: func(message p2p.MessageP2P) error { - return nil + ProcessReceivedMessageCalled: func(message p2p.MessageP2P) ([]byte, error) { + return nil, nil }, }, nil } diff --git a/testscommon/interceptorStub.go b/testscommon/interceptorStub.go index 54fc5be30af..095006e1f4e 100644 --- a/testscommon/interceptorStub.go +++ b/testscommon/interceptorStub.go @@ -2,25 +2,26 @@ package testscommon import ( "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/p2p" "github.com/multiversx/mx-chain-go/process" ) // InterceptorStub - type InterceptorStub struct { - ProcessReceivedMessageCalled func(message p2p.MessageP2P) error + ProcessReceivedMessageCalled func(message p2p.MessageP2P) ([]byte, error) SetInterceptedDebugHandlerCalled func(debugger process.InterceptedDebugger) error RegisterHandlerCalled func(handler func(topic string, hash []byte, data interface{})) CloseCalled func() error } // ProcessReceivedMessage - -func (is *InterceptorStub) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) error { +func (is *InterceptorStub) ProcessReceivedMessage(message p2p.MessageP2P, _ core.PeerID, _ p2p.MessageHandler) ([]byte, error) { if is.ProcessReceivedMessageCalled != nil { return is.ProcessReceivedMessageCalled(message) } - return nil + return nil, nil } // SetInterceptedDebugHandler - diff --git a/testscommon/outport/outportStub.go b/testscommon/outport/outportStub.go index e9cd2649d3e..c6a2996036b 100644 --- a/testscommon/outport/outportStub.go +++ b/testscommon/outport/outportStub.go @@ -11,6 +11,7 @@ type OutportStub struct { SaveValidatorsRatingCalled func(validatorsRating *outportcore.ValidatorsRating) SaveValidatorsPubKeysCalled func(validatorsPubKeys *outportcore.ValidatorsPubKeys) HasDriversCalled func() bool + SaveRoundsInfoCalled func(roundsInfo *outportcore.RoundsInfo) } // 
SaveBlock - @@ -65,7 +66,10 @@ func (as *OutportStub) Close() error { } // SaveRoundsInfo - -func (as *OutportStub) SaveRoundsInfo(_ *outportcore.RoundsInfo) { +func (as *OutportStub) SaveRoundsInfo(roundsInfo *outportcore.RoundsInfo) { + if as.SaveRoundsInfoCalled != nil { + as.SaveRoundsInfoCalled(roundsInfo) + } } diff --git a/testscommon/p2pmocks/messageProcessorStub.go b/testscommon/p2pmocks/messageProcessorStub.go index 5802dcc6785..69b8079e4d1 100644 --- a/testscommon/p2pmocks/messageProcessorStub.go +++ b/testscommon/p2pmocks/messageProcessorStub.go @@ -7,16 +7,16 @@ import ( // MessageProcessorStub - type MessageProcessorStub struct { - ProcessReceivedMessageCalled func(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error + ProcessReceivedMessageCalled func(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) } // ProcessReceivedMessage - -func (stub *MessageProcessorStub) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error { +func (stub *MessageProcessorStub) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { if stub.ProcessReceivedMessageCalled != nil { return stub.ProcessReceivedMessageCalled(message, fromConnectedPeer, source) } - return nil + return nil, nil } // IsInterfaceNil - diff --git a/testscommon/p2pmocks/messengerStub.go b/testscommon/p2pmocks/messengerStub.go index c48c95b9868..48a19977f07 100644 --- a/testscommon/p2pmocks/messengerStub.go +++ b/testscommon/p2pmocks/messengerStub.go @@ -4,6 +4,7 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/p2p" ) @@ -43,7 +44,7 @@ type MessengerStub struct { BroadcastUsingPrivateKeyCalled func(topic string, buff []byte, pid core.PeerID, skBytes []byte) BroadcastOnChannelUsingPrivateKeyCalled func(channel string, topic string, buff []byte, pid core.PeerID, skBytes []byte) SignUsingPrivateKeyCalled func(skBytes []byte, payload []byte) ([]byte, error) - ProcessReceivedMessageCalled func(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error + ProcessReceivedMessageCalled func(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) SetDebuggerCalled func(debugger p2p.Debugger) error HasCompatibleProtocolIDCalled func(address string) bool } @@ -345,11 +346,11 @@ func (ms *MessengerStub) SignUsingPrivateKey(skBytes []byte, payload []byte) ([] } // ProcessReceivedMessage - -func (ms *MessengerStub) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) error { +func (ms *MessengerStub) ProcessReceivedMessage(message p2p.MessageP2P, fromConnectedPeer core.PeerID, source p2p.MessageHandler) ([]byte, error) { if ms.ProcessReceivedMessageCalled != nil { return ms.ProcessReceivedMessageCalled(message, fromConnectedPeer, source) } - return nil + return nil, nil } // SetDebugger - diff --git a/testscommon/processMocks/forkDetectorStub.go b/testscommon/processMocks/forkDetectorStub.go index 80ddc4d2ebf..a6e4f2e2621 100644 --- a/testscommon/processMocks/forkDetectorStub.go +++ b/testscommon/processMocks/forkDetectorStub.go @@ -2,6 +2,7 @@ package processMocks import ( "github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/process" ) @@ -19,6 +20,8 @@ type 
ForkDetectorStub struct { RestoreToGenesisCalled func() ResetProbableHighestNonceCalled func() SetFinalToLastCheckpointCalled func() + ReceivedProofCalled func(proof data.HeaderProofHandler) + AddCheckpointCalled func(nonce uint64, round uint64, hash []byte) } // RestoreToGenesis - @@ -93,6 +96,20 @@ func (fdm *ForkDetectorStub) SetFinalToLastCheckpoint() { } } +// ReceivedProof - +func (fdm *ForkDetectorStub) ReceivedProof(proof data.HeaderProofHandler) { + if fdm.ReceivedProofCalled != nil { + fdm.ReceivedProofCalled(proof) + } +} + +// AddCheckpoint - +func (fdm *ForkDetectorStub) AddCheckpoint(nonce uint64, round uint64, hash []byte) { + if fdm.AddCheckpointCalled != nil { + fdm.AddCheckpointCalled(nonce, round, hash) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (fdm *ForkDetectorStub) IsInterfaceNil() bool { return fdm == nil diff --git a/testscommon/processMocks/headerProofHandlerStub.go b/testscommon/processMocks/headerProofHandlerStub.go new file mode 100644 index 00000000000..92c8ea2daf9 --- /dev/null +++ b/testscommon/processMocks/headerProofHandlerStub.go @@ -0,0 +1,82 @@ +package processMocks + +// HeaderProofHandlerStub - +type HeaderProofHandlerStub struct { + GetPubKeysBitmapCalled func() []byte + GetAggregatedSignatureCalled func() []byte + GetHeaderHashCalled func() []byte + GetHeaderEpochCalled func() uint32 + GetHeaderNonceCalled func() uint64 + GetHeaderShardIdCalled func() uint32 + GetHeaderRoundCalled func() uint64 + GetIsStartOfEpochCalled func() bool +} + +// GetPubKeysBitmap - +func (h *HeaderProofHandlerStub) GetPubKeysBitmap() []byte { + if h.GetPubKeysBitmapCalled != nil { + return h.GetPubKeysBitmapCalled() + } + return nil +} + +// GetAggregatedSignature - +func (h *HeaderProofHandlerStub) GetAggregatedSignature() []byte { + if h.GetAggregatedSignatureCalled != nil { + return h.GetAggregatedSignatureCalled() + } + return nil +} + +// GetHeaderHash - +func (h *HeaderProofHandlerStub) GetHeaderHash() []byte { + if h.GetHeaderHashCalled != nil { + return h.GetHeaderHashCalled() + } + return nil +} + +// GetHeaderEpoch - +func (h *HeaderProofHandlerStub) GetHeaderEpoch() uint32 { + if h.GetHeaderEpochCalled != nil { + return h.GetHeaderEpochCalled() + } + return 0 +} + +// GetHeaderNonce - +func (h *HeaderProofHandlerStub) GetHeaderNonce() uint64 { + if h.GetHeaderNonceCalled != nil { + return h.GetHeaderNonceCalled() + } + return 0 +} + +// GetHeaderShardId - +func (h *HeaderProofHandlerStub) GetHeaderShardId() uint32 { + if h.GetHeaderShardIdCalled != nil { + return h.GetHeaderShardIdCalled() + } + return 0 +} + +// GetHeaderRound - +func (h *HeaderProofHandlerStub) GetHeaderRound() uint64 { + if h.GetHeaderRoundCalled != nil { + return h.GetHeaderRoundCalled() + } + return 0 +} + +// GetIsStartOfEpoch - +func (h *HeaderProofHandlerStub) GetIsStartOfEpoch() bool { + if h.GetIsStartOfEpochCalled != nil { + return h.GetIsStartOfEpochCalled() + } + return false +} + +// IsInterfaceNil - +func (h *HeaderProofHandlerStub) IsInterfaceNil() bool { + return h == nil +} diff --git a/testscommon/ratingsInfoMock.go b/testscommon/ratingsInfoMock.go index 39f1f7897cd..5ca786994e8 100644 --- a/testscommon/ratingsInfoMock.go +++ b/testscommon/ratingsInfoMock.go @@ -1,6 +1,9 @@ package testscommon -import "github.com/multiversx/mx-chain-go/process" +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-go/process" +) // RatingsInfoMock - type 
RatingsInfoMock struct { @@ -11,6 +14,7 @@ type RatingsInfoMock struct { MetaRatingsStepDataProperty process.RatingsStepHandler ShardRatingsStepDataProperty process.RatingsStepHandler SelectionChancesProperty []process.SelectionChance + SetStatusHandlerCalled func(handler core.AppStatusHandler) error } // StartRating - @@ -48,6 +52,14 @@ func (rd *RatingsInfoMock) ShardChainRatingsStepHandler() process.RatingsStepHan return rd.ShardRatingsStepDataProperty } +// SetStatusHandler - +func (rd *RatingsInfoMock) SetStatusHandler(handler core.AppStatusHandler) error { + if rd.SetStatusHandlerCalled != nil { + return rd.SetStatusHandlerCalled(handler) + } + return nil +} + // IsInterfaceNil - func (rd *RatingsInfoMock) IsInterfaceNil() bool { return rd == nil diff --git a/testscommon/realConfigsHandling.go b/testscommon/realConfigsHandling.go index e58b36923f8..c59661b7234 100644 --- a/testscommon/realConfigsHandling.go +++ b/testscommon/realConfigsHandling.go @@ -6,6 +6,7 @@ import ( "path" "strings" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" ) @@ -83,6 +84,12 @@ func CreateTestConfigs(tempDir string, originalConfigsPath string) (*config.Conf return nil, err } + var nodesSetup config.NodesConfig + err = core.LoadJsonFile(&nodesSetup, path.Join(newConfigsPath, "nodesSetup.json")) + if err != nil { + return nil, err + } + // make the node pass the network wait constraints mainP2PConfig.Node.MinNumPeersToWaitForOnBootstrap = 0 mainP2PConfig.Node.ThresholdMinConnectedPeers = 0 @@ -114,6 +121,7 @@ func CreateTestConfigs(tempDir string, originalConfigsPath string) (*config.Conf }, EpochConfig: epochConfig, RoundConfig: roundConfig, + NodesConfig: &nodesSetup, }, nil } diff --git a/testscommon/requestHandlerStub.go b/testscommon/requestHandlerStub.go index 395e78e3100..3f911cd79a3 100644 --- a/testscommon/requestHandlerStub.go +++ b/testscommon/requestHandlerStub.go @@ -22,6 +22,8 @@ type RequestHandlerStub struct { RequestPeerAuthenticationsByHashesCalled func(destShardID uint32, hashes [][]byte) RequestValidatorInfoCalled func(hash []byte) RequestValidatorsInfoCalled func(hashes [][]byte) + RequestEquivalentProofByHashCalled func(headerShard uint32, headerHash []byte) + RequestEquivalentProofByNonceCalled func(headerShard uint32, headerNonce uint64) } // SetNumPeersToQuery - @@ -176,6 +178,20 @@ func (rhs *RequestHandlerStub) RequestValidatorsInfo(hashes [][]byte) { } } +// RequestEquivalentProofByHash - +func (rhs *RequestHandlerStub) RequestEquivalentProofByHash(headerShard uint32, headerHash []byte) { + if rhs.RequestEquivalentProofByHashCalled != nil { + rhs.RequestEquivalentProofByHashCalled(headerShard, headerHash) + } +} + +// RequestEquivalentProofByNonce - +func (rhs *RequestHandlerStub) RequestEquivalentProofByNonce(headerShard uint32, headerNonce uint64) { + if rhs.RequestEquivalentProofByNonceCalled != nil { + rhs.RequestEquivalentProofByNonceCalled(headerShard, headerNonce) + } +} + // IsInterfaceNil returns true if there is no value under the interface func (rhs *RequestHandlerStub) IsInterfaceNil() bool { return rhs == nil diff --git a/testscommon/roundHandlerMock.go b/testscommon/roundHandlerMock.go index 6c5d45cc7bc..598e11feb42 100644 --- a/testscommon/roundHandlerMock.go +++ b/testscommon/roundHandlerMock.go @@ -27,6 +27,9 @@ func (rndm *RoundHandlerMock) BeforeGenesis() bool { return false } +// RevertOneRound - +func (rndm 
*RoundHandlerMock) RevertOneRound() {} + // Index - func (rndm *RoundHandlerMock) Index() int64 { if rndm.IndexCalled != nil { diff --git a/testscommon/shardedDataCacheNotifierMock.go b/testscommon/shardedDataCacheNotifierMock.go index d5af2000ab3..f6043415b08 100644 --- a/testscommon/shardedDataCacheNotifierMock.go +++ b/testscommon/shardedDataCacheNotifierMock.go @@ -4,7 +4,9 @@ import ( "sync" "github.com/multiversx/mx-chain-core-go/core/counting" + "github.com/multiversx/mx-chain-go/storage" + cacheMocks "github.com/multiversx/mx-chain-go/testscommon/cache" ) // ShardedDataCacheNotifierMock - @@ -31,7 +33,7 @@ func (mock *ShardedDataCacheNotifierMock) ShardDataStore(cacheId string) (c stor cache, found := mock.caches[cacheId] if !found { - cache = NewCacherMock() + cache = cacheMocks.NewCacherMock() mock.caches[cacheId] = cache } diff --git a/testscommon/shardingMocks/nodesCoordinatorMock.go b/testscommon/shardingMocks/nodesCoordinatorMock.go index 9f1b872e2ab..033c394a91f 100644 --- a/testscommon/shardingMocks/nodesCoordinatorMock.go +++ b/testscommon/shardingMocks/nodesCoordinatorMock.go @@ -18,18 +18,21 @@ type NodesCoordinatorMock struct { ShardId uint32 NbShards uint32 GetSelectedPublicKeysCalled func(selection []byte, shardId uint32, epoch uint32) (publicKeys []string, err error) - GetValidatorsPublicKeysCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) + GetValidatorsPublicKeysCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (string, []string, error) GetValidatorsRewardsAddressesCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) SetNodesPerShardsCalled func(nodes map[uint32][]nodesCoordinator.Validator, epoch uint32) error - ComputeValidatorsGroupCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) + ComputeValidatorsGroupCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) GetValidatorWithPublicKeyCalled func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) GetAllEligibleValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) + GetAllEligibleValidatorsPublicKeysForShardCalled func(epoch uint32, shardID uint32) ([]string, error) GetAllWaitingValidatorsPublicKeysCalled func() (map[uint32][][]byte, error) - ConsensusGroupSizeCalled func(uint32) int + ConsensusGroupSizeCalled func(uint32, uint32) int GetValidatorsIndexesCalled func(publicKeys []string, epoch uint32) ([]uint64, error) + GetConsensusWhitelistedNodesCalled func(epoch uint32) (map[string]struct{}, error) GetAllShuffledOutValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) GetShuffledOutToAuctionValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) GetNumTotalEligibleCalled func() uint64 + GetCachedEpochsCalled func() map[uint32]struct{} } // NewNodesCoordinatorMock - @@ -96,6 +99,14 @@ func (ncm *NodesCoordinatorMock) GetAllEligibleValidatorsPublicKeys(epoch uint32 return nil, nil } +// GetAllEligibleValidatorsPublicKeysForShard - +func (ncm *NodesCoordinatorMock) GetAllEligibleValidatorsPublicKeysForShard(epoch uint32, shardID uint32) ([]string, error) { + if ncm.GetAllEligibleValidatorsPublicKeysForShardCalled != nil { + return 
ncm.GetAllEligibleValidatorsPublicKeysForShardCalled(epoch, shardID) + } + return nil, nil +} + // GetAllWaitingValidatorsPublicKeys - func (ncm *NodesCoordinatorMock) GetAllWaitingValidatorsPublicKeys(_ uint32) (map[uint32][][]byte, error) { if ncm.GetAllWaitingValidatorsPublicKeysCalled != nil { @@ -155,14 +166,14 @@ func (ncm *NodesCoordinatorMock) GetConsensusValidatorsPublicKeys( round uint64, shardId uint32, epoch uint32, -) ([]string, error) { +) (string, []string, error) { if ncm.GetValidatorsPublicKeysCalled != nil { return ncm.GetValidatorsPublicKeysCalled(randomness, round, shardId, epoch) } - validators, err := ncm.ComputeConsensusGroup(randomness, round, shardId, epoch) + leader, validators, err := ncm.ComputeConsensusGroup(randomness, round, shardId, epoch) if err != nil { - return nil, err + return "", nil, err } valGrStr := make([]string, 0) @@ -171,7 +182,7 @@ func (ncm *NodesCoordinatorMock) GetConsensusValidatorsPublicKeys( valGrStr = append(valGrStr, string(v.PubKey())) } - return valGrStr, nil + return string(leader.PubKey()), valGrStr, nil } // SetNodesPerShards - @@ -204,7 +215,7 @@ func (ncm *NodesCoordinatorMock) ComputeConsensusGroup( round uint64, shardId uint32, epoch uint32, -) ([]nodesCoordinator.Validator, error) { +) (nodesCoordinator.Validator, []nodesCoordinator.Validator, error) { var consensusSize uint32 if ncm.ComputeValidatorsGroupCalled != nil { @@ -218,7 +229,7 @@ func (ncm *NodesCoordinatorMock) ComputeConsensusGroup( } if randomess == nil { - return nil, nodesCoordinator.ErrNilRandomness + return nil, nil, nodesCoordinator.ErrNilRandomness } validatorsGroup := make([]nodesCoordinator.Validator, 0) @@ -227,13 +238,13 @@ func (ncm *NodesCoordinatorMock) ComputeConsensusGroup( validatorsGroup = append(validatorsGroup, ncm.Validators[shardId][i]) } - return validatorsGroup, nil + return validatorsGroup[0], validatorsGroup, nil } -// ConsensusGroupSize - -func (ncm *NodesCoordinatorMock) ConsensusGroupSize(shardId uint32) int { +// ConsensusGroupSizeForShardAndEpoch - +func (ncm *NodesCoordinatorMock) ConsensusGroupSizeForShardAndEpoch(shardId uint32, epoch uint32) int { if ncm.ConsensusGroupSizeCalled != nil { - return ncm.ConsensusGroupSizeCalled(shardId) + return ncm.ConsensusGroupSizeCalled(shardId, epoch) } return 1 } @@ -285,9 +296,10 @@ func (ncm *NodesCoordinatorMock) ShuffleOutForEpoch(_ uint32) { } // GetConsensusWhitelistedNodes return the whitelisted nodes allowed to send consensus messages, for each of the shards -func (ncm *NodesCoordinatorMock) GetConsensusWhitelistedNodes( - _ uint32, -) (map[string]struct{}, error) { +func (ncm *NodesCoordinatorMock) GetConsensusWhitelistedNodes(epoch uint32) (map[string]struct{}, error) { + if ncm.GetConsensusWhitelistedNodesCalled != nil { + return ncm.GetConsensusWhitelistedNodesCalled(epoch) + } return make(map[string]struct{}), nil } @@ -306,6 +318,15 @@ func (ncm *NodesCoordinatorMock) GetWaitingEpochsLeftForPublicKey(_ []byte) (uin return 0, nil } +// GetCachedEpochs - +func (ncm *NodesCoordinatorMock) GetCachedEpochs() map[uint32]struct{} { + if ncm.GetCachedEpochsCalled != nil { + return ncm.GetCachedEpochsCalled() + } + + return make(map[uint32]struct{}) +} + // IsInterfaceNil - func (ncm *NodesCoordinatorMock) IsInterfaceNil() bool { return ncm == nil diff --git a/testscommon/shardingMocks/nodesCoordinatorMocks/randomSelectorMock.go b/testscommon/shardingMocks/nodesCoordinatorMocks/randomSelectorMock.go new file mode 100644 index 00000000000..13c74dad98d --- /dev/null +++ 
b/testscommon/shardingMocks/nodesCoordinatorMocks/randomSelectorMock.go @@ -0,0 +1,19 @@ +package nodesCoordinatorMocks + +// RandomSelectorMock is a mock for the RandomSelector interface +type RandomSelectorMock struct { + SelectCalled func(randSeed []byte, sampleSize uint32) ([]uint32, error) +} + +// Select calls the mocked method +func (rsm *RandomSelectorMock) Select(randSeed []byte, sampleSize uint32) ([]uint32, error) { + if rsm.SelectCalled != nil { + return rsm.SelectCalled(randSeed, sampleSize) + } + return nil, nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (rsm *RandomSelectorMock) IsInterfaceNil() bool { + return rsm == nil +} diff --git a/testscommon/shardingMocks/nodesCoordinatorStub.go b/testscommon/shardingMocks/nodesCoordinatorStub.go index a142f0509ed..38be8b55a63 100644 --- a/testscommon/shardingMocks/nodesCoordinatorStub.go +++ b/testscommon/shardingMocks/nodesCoordinatorStub.go @@ -9,19 +9,23 @@ import ( // NodesCoordinatorStub - type NodesCoordinatorStub struct { - GetValidatorsPublicKeysCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) - GetValidatorsRewardsAddressesCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) - GetValidatorWithPublicKeyCalled func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) - GetAllValidatorsPublicKeysCalled func() (map[uint32][][]byte, error) - GetAllWaitingValidatorsPublicKeysCalled func(_ uint32) (map[uint32][][]byte, error) - GetAllEligibleValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) - ConsensusGroupSizeCalled func(shardID uint32) int - ComputeConsensusGroupCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (validatorsGroup []nodesCoordinator.Validator, err error) - EpochStartPrepareCalled func(metaHdr data.HeaderHandler, body data.BodyHandler) - GetConsensusWhitelistedNodesCalled func(epoch uint32) (map[string]struct{}, error) - GetOwnPublicKeyCalled func() []byte - GetWaitingEpochsLeftForPublicKeyCalled func(publicKey []byte) (uint32, error) - GetNumTotalEligibleCalled func() uint64 + GetValidatorsPublicKeysCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (string, []string, error) + GetValidatorsRewardsAddressesCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) ([]string, error) + GetValidatorWithPublicKeyCalled func(publicKey []byte) (validator nodesCoordinator.Validator, shardId uint32, err error) + GetAllValidatorsPublicKeysCalled func() (map[uint32][][]byte, error) + GetAllWaitingValidatorsPublicKeysCalled func(_ uint32) (map[uint32][][]byte, error) + GetAllEligibleValidatorsPublicKeysCalled func(epoch uint32) (map[uint32][][]byte, error) + GetAllEligibleValidatorsPublicKeysForShardCalled func(epoch uint32, shardID uint32) ([]string, error) + GetValidatorsIndexesCalled func(pubKeys []string, epoch uint32) ([]uint64, error) + ConsensusGroupSizeCalled func(shardID uint32, epoch uint32) int + ComputeConsensusGroupCalled func(randomness []byte, round uint64, shardId uint32, epoch uint32) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) + EpochStartPrepareCalled func(metaHdr data.HeaderHandler, body data.BodyHandler) + GetConsensusWhitelistedNodesCalled func(epoch uint32) (map[string]struct{}, error) + GetOwnPublicKeyCalled func() []byte + GetWaitingEpochsLeftForPublicKeyCalled func(publicKey []byte) (uint32, error) + 
GetNumTotalEligibleCalled func() uint64 + ShardIdForEpochCalled func(epoch uint32) (uint32, error) + GetCachedEpochsCalled func() map[uint32]struct{} } // NodesCoordinatorToRegistry - @@ -69,6 +73,14 @@ func (ncm *NodesCoordinatorStub) GetAllEligibleValidatorsPublicKeys(epoch uint32 return nil, nil } +// GetAllEligibleValidatorsPublicKeysForShard - +func (ncm *NodesCoordinatorStub) GetAllEligibleValidatorsPublicKeysForShard(epoch uint32, shardID uint32) ([]string, error) { + if ncm.GetAllEligibleValidatorsPublicKeysForShardCalled != nil { + return ncm.GetAllEligibleValidatorsPublicKeysForShardCalled(epoch, shardID) + } + return nil, nil +} + // GetAllWaitingValidatorsPublicKeys - func (ncm *NodesCoordinatorStub) GetAllWaitingValidatorsPublicKeys(epoch uint32) (map[uint32][][]byte, error) { if ncm.GetAllWaitingValidatorsPublicKeysCalled != nil { @@ -106,7 +118,10 @@ func (ncm *NodesCoordinatorStub) GetAllValidatorsPublicKeys(_ uint32) (map[uint3 } // GetValidatorsIndexes - -func (ncm *NodesCoordinatorStub) GetValidatorsIndexes(_ []string, _ uint32) ([]uint64, error) { +func (ncm *NodesCoordinatorStub) GetValidatorsIndexes(pubkeys []string, epoch uint32) ([]uint64, error) { + if ncm.GetValidatorsIndexesCalled != nil { + return ncm.GetValidatorsIndexesCalled(pubkeys, epoch) + } return nil, nil } @@ -116,20 +131,18 @@ func (ncm *NodesCoordinatorStub) ComputeConsensusGroup( round uint64, shardId uint32, epoch uint32, -) (validatorsGroup []nodesCoordinator.Validator, err error) { +) (leader nodesCoordinator.Validator, validatorsGroup []nodesCoordinator.Validator, err error) { if ncm.ComputeConsensusGroupCalled != nil { return ncm.ComputeConsensusGroupCalled(randomness, round, shardId, epoch) } - var list []nodesCoordinator.Validator - - return list, nil + return nil, nil, nil } -// ConsensusGroupSize - -func (ncm *NodesCoordinatorStub) ConsensusGroupSize(shardID uint32) int { +// ConsensusGroupSizeForShardAndEpoch - +func (ncm *NodesCoordinatorStub) ConsensusGroupSizeForShardAndEpoch(shardID uint32, epoch uint32) int { if ncm.ConsensusGroupSizeCalled != nil { - return ncm.ConsensusGroupSizeCalled(shardID) + return ncm.ConsensusGroupSizeCalled(shardID, epoch) } return 1 } @@ -140,12 +153,12 @@ func (ncm *NodesCoordinatorStub) GetConsensusValidatorsPublicKeys( round uint64, shardId uint32, epoch uint32, -) ([]string, error) { +) (string, []string, error) { if ncm.GetValidatorsPublicKeysCalled != nil { return ncm.GetValidatorsPublicKeysCalled(randomness, round, shardId, epoch) } - return nil, nil + return "", nil, nil } // SetNodesPerShards - @@ -165,8 +178,12 @@ func (ncm *NodesCoordinatorStub) GetSavedStateKey() []byte { // ShardIdForEpoch returns the nodesCoordinator configured ShardId for specified epoch if epoch configuration exists, // otherwise error -func (ncm *NodesCoordinatorStub) ShardIdForEpoch(_ uint32) (uint32, error) { - panic("not implemented") +func (ncm *NodesCoordinatorStub) ShardIdForEpoch(epoch uint32) (uint32, error) { + + if ncm.ShardIdForEpochCalled != nil { + return ncm.ShardIdForEpochCalled(epoch) + } + return 0, nil } // ShuffleOutForEpoch verifies if the shards changed in the new epoch and calls the shuffleOutHandler @@ -210,6 +227,14 @@ func (ncm *NodesCoordinatorStub) GetWaitingEpochsLeftForPublicKey(publicKey []by return 0, nil } +// GetCachedEpochs - +func (ncm *NodesCoordinatorStub) GetCachedEpochs() map[uint32]struct{} { + if ncm.GetCachedEpochsCalled != nil { + return ncm.GetCachedEpochsCalled() + } + return make(map[uint32]struct{}) +} + // IsInterfaceNil returns 
true if there is no value under the interface func (ncm *NodesCoordinatorStub) IsInterfaceNil() bool { return ncm == nil diff --git a/testscommon/shardingMocks/shufflerMock.go b/testscommon/shardingMocks/shufflerMock.go index 82015b638a3..a96b5ea500c 100644 --- a/testscommon/shardingMocks/shufflerMock.go +++ b/testscommon/shardingMocks/shufflerMock.go @@ -8,16 +8,6 @@ import ( type NodeShufflerMock struct { } -// UpdateParams - -func (nsm *NodeShufflerMock) UpdateParams( - _ uint32, - _ uint32, - _ float32, - _ bool, -) { - -} - // UpdateNodeLists - func (nsm *NodeShufflerMock) UpdateNodeLists(args nodesCoordinator.ArgsUpdateNodes) (*nodesCoordinator.ResUpdateNodes, error) { return &nodesCoordinator.ResUpdateNodes{ diff --git a/testscommon/state/testTrie.go b/testscommon/state/testTrie.go new file mode 100644 index 00000000000..8744009aa18 --- /dev/null +++ b/testscommon/state/testTrie.go @@ -0,0 +1,55 @@ +package state + +import ( + "fmt" + + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" + disabled2 "github.com/multiversx/mx-chain-go/common/disabled" + "github.com/multiversx/mx-chain-go/common/statistics/disabled" + "github.com/multiversx/mx-chain-go/config" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/trie" +) + +// GetDefaultTrieParameters - +func GetDefaultTrieParameters() (common.StorageManager, marshal.Marshalizer, hashing.Hasher) { + db := testscommon.NewMemDbMock() + hasher := &hashingMocks.HasherMock{} + marshaller := &marshallerMock.MarshalizerMock{} + + tsmArgs := trie.NewTrieStorageManagerArgs{ + MainStorer: db, + Marshalizer: marshaller, + Hasher: hasher, + GeneralConfig: config.TrieStorageManagerConfig{ + SnapshotsGoroutineNum: 5, + }, + IdleProvider: disabled2.NewProcessStatusHandler(), + Identifier: "identifier", + StatsCollector: disabled.NewStateStatistics(), + } + tsm, _ := trie.NewTrieStorageManager(tsmArgs) + return tsm, marshaller, hasher +} + +// GetNewTrie - +func GetNewTrie() common.Trie { + tsm, marshaller, hasher := GetDefaultTrieParameters() + tr, _ := trie.NewTrie(tsm, marshaller, hasher, &enableEpochsHandlerMock.EnableEpochsHandlerStub{}, 5) + + return tr +} + +// AddDataToTrie - +func AddDataToTrie(tr common.Trie, numLeaves int) { + for i := 0; i < numLeaves; i++ { + val := fmt.Sprintf("value%v", i) + _ = tr.Update([]byte(val), []byte(val)) + } + _ = tr.Commit() +} diff --git a/testscommon/stateStatisticsHandlerStub.go b/testscommon/stateStatisticsHandlerStub.go index bc13bea90d4..169d0156a91 100644 --- a/testscommon/stateStatisticsHandlerStub.go +++ b/testscommon/stateStatisticsHandlerStub.go @@ -2,20 +2,22 @@ package testscommon // StateStatisticsHandlerStub - type StateStatisticsHandlerStub struct { - ResetCalled func() - ResetSnapshotCalled func() - IncrementCacheCalled func() - CacheCalled func() uint64 - IncrementSnapshotCacheCalled func() - SnapshotCacheCalled func() uint64 - IncrementPersisterCalled func(epoch uint32) - PersisterCalled func(epoch uint32) uint64 - IncrementSnapshotPersisterCalled func(epoch uint32) - 
SnapshotPersisterCalled func(epoch uint32) uint64 - IncrementTrieCalled func() - TrieCalled func() uint64 - ProcessingStatsCalled func() []string - SnapshotStatsCalled func() []string + ResetCalled func() + ResetSnapshotCalled func() + IncrCacheCalled func() + CacheCalled func() uint64 + IncrSnapshotCacheCalled func() + SnapshotCacheCalled func() uint64 + IncrPersisterCalled func(epoch uint32) + IncrWritePersisterCalled func(epoch uint32) + PersisterCalled func(epoch uint32) uint64 + WritePersisterCalled func(epoch uint32) uint64 + IncrSnapshotPersisterCalled func(epoch uint32) + SnapshotPersisterCalled func(epoch uint32) uint64 + IncrTrieCalled func() + TrieCalled func() uint64 + ProcessingStatsCalled func() []string + SnapshotStatsCalled func() []string } // Reset - @@ -32,10 +34,10 @@ func (stub *StateStatisticsHandlerStub) ResetSnapshot() { } } -// IncrementCache - -func (stub *StateStatisticsHandlerStub) IncrementCache() { - if stub.IncrementCacheCalled != nil { - stub.IncrementCacheCalled() +// IncrCache - +func (stub *StateStatisticsHandlerStub) IncrCache() { + if stub.IncrCacheCalled != nil { + stub.IncrCacheCalled() } } @@ -48,10 +50,10 @@ func (stub *StateStatisticsHandlerStub) Cache() uint64 { return 0 } -// IncrementSnapshotCache - -func (stub *StateStatisticsHandlerStub) IncrementSnapshotCache() { - if stub.IncrementSnapshotCacheCalled != nil { - stub.IncrementSnapshotCacheCalled() +// IncrSnapshotCache - +func (stub *StateStatisticsHandlerStub) IncrSnapshotCache() { + if stub.IncrSnapshotCacheCalled != nil { + stub.IncrSnapshotCacheCalled() } } @@ -64,10 +66,17 @@ func (stub *StateStatisticsHandlerStub) SnapshotCache() uint64 { return 0 } -// IncrementPersister - -func (stub *StateStatisticsHandlerStub) IncrementPersister(epoch uint32) { - if stub.IncrementPersisterCalled != nil { - stub.IncrementPersisterCalled(epoch) +// IncrPersister - +func (stub *StateStatisticsHandlerStub) IncrPersister(epoch uint32) { + if stub.IncrPersisterCalled != nil { + stub.IncrPersisterCalled(epoch) + } +} + +// IncrWritePersister - +func (stub *StateStatisticsHandlerStub) IncrWritePersister(epoch uint32) { + if stub.IncrWritePersisterCalled != nil { + stub.IncrWritePersisterCalled(epoch) } } @@ -80,10 +89,19 @@ func (stub *StateStatisticsHandlerStub) Persister(epoch uint32) uint64 { return 0 } -// IncrementSnapshotPersister - -func (stub *StateStatisticsHandlerStub) IncrementSnapshotPersister(epoch uint32) { - if stub.IncrementSnapshotPersisterCalled != nil { - stub.IncrementSnapshotPersisterCalled(epoch) +// WritePersister - +func (stub *StateStatisticsHandlerStub) WritePersister(epoch uint32) uint64 { + if stub.WritePersisterCalled != nil { + return stub.WritePersisterCalled(epoch) + } + + return 0 +} + +// IncrSnapshotPersister - +func (stub *StateStatisticsHandlerStub) IncrSnapshotPersister(epoch uint32) { + if stub.IncrSnapshotPersisterCalled != nil { + stub.IncrSnapshotPersisterCalled(epoch) } } @@ -96,10 +114,10 @@ func (stub *StateStatisticsHandlerStub) SnapshotPersister(epoch uint32) uint64 { return 0 } -// IncrementTrie - -func (stub *StateStatisticsHandlerStub) IncrementTrie() { - if stub.IncrementTrieCalled != nil { - stub.IncrementTrieCalled() +// IncrTrie - +func (stub *StateStatisticsHandlerStub) IncrTrie() { + if stub.IncrTrieCalled != nil { + stub.IncrTrieCalled() } } diff --git a/trie/branchNode.go b/trie/branchNode.go index 39f8402d289..6fac5a13581 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -12,6 +12,7 @@ import ( 
"github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever/trieNodeData" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) @@ -804,7 +805,7 @@ func (bn *branchNode) getAllLeavesOnChannel( continue } - clonedKeyBuilder := keyBuilder.Clone() + clonedKeyBuilder := keyBuilder.ShallowClone() clonedKeyBuilder.BuildKey([]byte{byte(i)}) err = bn.children[i].getAllLeavesOnChannel(leavesChannel, clonedKeyBuilder, trieLeafParser, db, marshalizer, chanClose, ctx) if err != nil { @@ -980,7 +981,7 @@ func (bn *branchNode) collectLeavesForMigration( return false, err } - clonedKeyBuilder := keyBuilder.Clone() + clonedKeyBuilder := keyBuilder.ShallowClone() clonedKeyBuilder.BuildKey([]byte{byte(i)}) shouldContinueMigrating, err := bn.children[i].collectLeavesForMigration(migrationArgs, db, clonedKeyBuilder) if err != nil { @@ -995,6 +996,30 @@ func (bn *branchNode) collectLeavesForMigration( return true, nil } +func (bn *branchNode) getNodeData(keyBuilder common.KeyBuilder) ([]common.TrieNodeData, error) { + err := bn.isEmptyOrNil() + if err != nil { + return nil, fmt.Errorf("getNodeData error %w", err) + } + + data := make([]common.TrieNodeData, 0) + for i := range bn.EncodedChildren { + if len(bn.EncodedChildren[i]) == 0 { + continue + } + + clonedKeyBuilder := keyBuilder.DeepClone() + clonedKeyBuilder.BuildKey([]byte{byte(i)}) + childData, err := trieNodeData.NewIntermediaryNodeData(clonedKeyBuilder, bn.EncodedChildren[i]) + if err != nil { + return nil, err + } + data = append(data, childData) + } + + return data, nil +} + // IsInterfaceNil returns true if there is no value under the interface func (bn *branchNode) IsInterfaceNil() bool { return bn == nil diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index 17e0c380d8e..687a0d01023 100644 --- a/trie/branchNode_test.go +++ b/trie/branchNode_test.go @@ -16,6 +16,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/stretchr/testify/assert" ) @@ -1504,3 +1505,55 @@ func TestBranchNode_revertChildrenVersionSliceIfNeeded(t *testing.T) { assert.Nil(t, bn.ChildrenVersion) }) } + +func TestBranchNode_getNodeData(t *testing.T) { + t.Parallel() + + t.Run("nil node", func(t *testing.T) { + t.Parallel() + + var bn *branchNode + nodeData, err := bn.getNodeData(keyBuilder.NewDisabledKeyBuilder()) + assert.Nil(t, nodeData) + assert.True(t, errors.Is(err, ErrNilBranchNode)) + }) + t.Run("gets data from all non-nil children", func(t *testing.T) { + t.Parallel() + + tr := initTrie() + _ = tr.Update([]byte("111"), []byte("111")) + _ = tr.Update([]byte("aaa"), []byte("aaa")) + _ = tr.Commit() + + bn, ok := tr.root.(*branchNode) + assert.True(t, ok) + + hashSize := 32 + keySize := 1 + kb := keyBuilder.NewKeyBuilder() + nodeData, err := bn.getNodeData(kb) + assert.Nil(t, err) + assert.Equal(t, 3, len(nodeData)) + + // branch node as child + firstChildData := nodeData[0] + assert.Equal(t, uint(1), firstChildData.GetKeyBuilder().Size()) + assert.Equal(t, 
bn.EncodedChildren[1], firstChildData.GetData()) + assert.Equal(t, uint64(hashSize+keySize), firstChildData.Size()) + assert.False(t, firstChildData.IsLeaf()) + + // leaf node as child + seconChildData := nodeData[1] + assert.Equal(t, uint(1), seconChildData.GetKeyBuilder().Size()) + assert.Equal(t, bn.EncodedChildren[5], seconChildData.GetData()) + assert.Equal(t, uint64(hashSize+keySize), seconChildData.Size()) + assert.False(t, seconChildData.IsLeaf()) + + // extension node as child + thirdChildData := nodeData[2] + assert.Equal(t, uint(1), thirdChildData.GetKeyBuilder().Size()) + assert.Equal(t, bn.EncodedChildren[7], thirdChildData.GetData()) + assert.Equal(t, uint64(hashSize+keySize), thirdChildData.Size()) + assert.False(t, thirdChildData.IsLeaf()) + }) +} diff --git a/trie/errors.go b/trie/errors.go index 9cc2588e501..a879fd6c94c 100644 --- a/trie/errors.go +++ b/trie/errors.go @@ -123,3 +123,9 @@ var ErrNilTrieLeafParser = errors.New("nil trie leaf parser") // ErrInvalidNodeVersion signals that an invalid node version has been provided var ErrInvalidNodeVersion = errors.New("invalid node version provided") + +// ErrEmptyInitialIteratorState signals that an empty initial iterator state was provided +var ErrEmptyInitialIteratorState = errors.New("empty initial iterator state") + +// ErrInvalidIteratorState signals that an invalid iterator state was provided +var ErrInvalidIteratorState = errors.New("invalid iterator state") diff --git a/trie/extensionNode.go b/trie/extensionNode.go index 9c05caaeebe..9ff667a7b63 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -14,6 +14,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever/trieNodeData" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) @@ -667,7 +668,7 @@ func (en *extensionNode) getAllLeavesOnChannel( } keyBuilder.BuildKey(en.Key) - err = en.child.getAllLeavesOnChannel(leavesChannel, keyBuilder.Clone(), trieLeafParser, db, marshalizer, chanClose, ctx) + err = en.child.getAllLeavesOnChannel(leavesChannel, keyBuilder.ShallowClone(), trieLeafParser, db, marshalizer, chanClose, ctx) if err != nil { return err } @@ -784,7 +785,25 @@ func (en *extensionNode) collectLeavesForMigration( } keyBuilder.BuildKey(en.Key) - return en.child.collectLeavesForMigration(migrationArgs, db, keyBuilder.Clone()) + return en.child.collectLeavesForMigration(migrationArgs, db, keyBuilder.ShallowClone()) +} + +func (en *extensionNode) getNodeData(keyBuilder common.KeyBuilder) ([]common.TrieNodeData, error) { + err := en.isEmptyOrNil() + if err != nil { + return nil, fmt.Errorf("getNodeData error %w", err) + } + + data := make([]common.TrieNodeData, 1) + clonedKeyBuilder := keyBuilder.DeepClone() + clonedKeyBuilder.BuildKey(en.Key) + childData, err := trieNodeData.NewIntermediaryNodeData(clonedKeyBuilder, en.EncodedChild) + if err != nil { + return nil, err + } + + data[0] = childData + return data, nil } // IsInterfaceNil returns true if there is no value under the interface diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index 5ae78db766a..e1c099e77a4 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -14,6 +14,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" 
"github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/stretchr/testify/assert" ) @@ -1079,3 +1080,34 @@ func TestExtensionNode_getVersion(t *testing.T) { assert.Nil(t, err) }) } + +func TestExtensionNode_getNodeData(t *testing.T) { + t.Parallel() + + t.Run("nil node", func(t *testing.T) { + t.Parallel() + + var en *extensionNode + nodeData, err := en.getNodeData(keyBuilder.NewDisabledKeyBuilder()) + assert.Nil(t, nodeData) + assert.True(t, errors.Is(err, ErrNilExtensionNode)) + }) + t.Run("gets data from child", func(t *testing.T) { + t.Parallel() + + en, _ := getEnAndCollapsedEn() + en, _ = en.getCollapsedEn() + hashSize := 32 + keySize := 1 + + kb := keyBuilder.NewKeyBuilder() + nodeData, err := en.getNodeData(kb) + assert.Nil(t, err) + assert.Equal(t, 1, len(nodeData)) + + assert.Equal(t, uint(1), nodeData[0].GetKeyBuilder().Size()) + assert.Equal(t, en.EncodedChild, nodeData[0].GetData()) + assert.Equal(t, uint64(hashSize+keySize), nodeData[0].Size()) + assert.False(t, nodeData[0].IsLeaf()) + }) +} diff --git a/trie/interface.go b/trie/interface.go index da29040563b..d1d8fab1502 100644 --- a/trie/interface.go +++ b/trie/interface.go @@ -37,6 +37,7 @@ type node interface { getDirtyHashes(common.ModifiedHashes) error getChildren(db common.TrieStorageInteractor) ([]node, error) isValid() bool + getNodeData(common.KeyBuilder) ([]common.TrieNodeData, error) setDirty(bool) loadChildren(func([]byte) (node, error)) ([][]byte, []node, error) getAllLeavesOnChannel(chan core.KeyValueHolder, common.KeyBuilder, common.TrieLeafParser, common.TrieStorageInteractor, marshal.Marshalizer, chan struct{}, context.Context) error diff --git a/trie/keyBuilder/disabledKeyBuilder.go b/trie/keyBuilder/disabledKeyBuilder.go index a930f4baff1..71c2022d372 100644 --- a/trie/keyBuilder/disabledKeyBuilder.go +++ b/trie/keyBuilder/disabledKeyBuilder.go @@ -22,11 +22,26 @@ func (dkb *disabledKeyBuilder) GetKey() ([]byte, error) { return []byte{}, nil } -// Clone returns a new disabled key builder -func (dkb *disabledKeyBuilder) Clone() common.KeyBuilder { +// GetRawKey returns an empty byte array for this implementation +func (dkb *disabledKeyBuilder) GetRawKey() []byte { + return []byte{} +} + +// ShallowClone returns a new disabled key builder +func (dkb *disabledKeyBuilder) ShallowClone() common.KeyBuilder { return &disabledKeyBuilder{} } +// DeepClone returns a new disabled key builder +func (dkb *disabledKeyBuilder) DeepClone() common.KeyBuilder { + return &disabledKeyBuilder{} +} + +// Size returns 0 for this implementation +func (dkb *disabledKeyBuilder) Size() uint { + return 0 +} + // IsInterfaceNil returns true if there is no value under the interface func (dkb *disabledKeyBuilder) IsInterfaceNil() bool { return dkb == nil diff --git a/trie/keyBuilder/disabledKeyBuilder_test.go b/trie/keyBuilder/disabledKeyBuilder_test.go index cdd63acfa1f..2beb7a6cfb2 100644 --- a/trie/keyBuilder/disabledKeyBuilder_test.go +++ b/trie/keyBuilder/disabledKeyBuilder_test.go @@ -26,6 +26,7 @@ func TestDisabledKeyBuilder(t *testing.T) { require.Nil(t, err) require.True(t, bytes.Equal(key, []byte{})) - clonedBuilder := builder.Clone() + clonedBuilder := builder.ShallowClone() require.Equal(t, &disabledKeyBuilder{}, clonedBuilder) + require.Equal(t, uint(0), 
clonedBuilder.Size()) } diff --git a/trie/keyBuilder/keyBuilder.go b/trie/keyBuilder/keyBuilder.go index 787b1d66e0e..c1b7f78f62a 100644 --- a/trie/keyBuilder/keyBuilder.go +++ b/trie/keyBuilder/keyBuilder.go @@ -34,13 +34,28 @@ func (kb *keyBuilder) GetKey() ([]byte, error) { return hexToTrieKeyBytes(kb.key) } -// Clone returns a new KeyBuilder with the same key -func (kb *keyBuilder) Clone() common.KeyBuilder { +// GetRawKey returns the key as it is, without transforming it +func (kb *keyBuilder) GetRawKey() []byte { + return kb.key +} + +// ShallowClone returns a new KeyBuilder with the same key. The key slice points to the same memory location. +func (kb *keyBuilder) ShallowClone() common.KeyBuilder { return &keyBuilder{ key: kb.key, } } +// DeepClone returns a new KeyBuilder with the same key. This allocates a new memory location for the key slice. +func (kb *keyBuilder) DeepClone() common.KeyBuilder { + clonedKey := make([]byte, len(kb.key)) + copy(clonedKey, kb.key) + + return &keyBuilder{ + key: clonedKey, + } +} + // hexToTrieKeyBytes transforms hex nibbles into key bytes. The hex terminator is removed from the end of the hex slice, // and then the hex slice is reversed when forming the key bytes. func hexToTrieKeyBytes(hex []byte) ([]byte, error) { @@ -60,6 +75,11 @@ func hexToTrieKeyBytes(hex []byte) ([]byte, error) { return key, nil } +// Size returns the size of the key +func (kb *keyBuilder) Size() uint { + return uint(len(kb.key)) +} + // IsInterfaceNil returns true if there is no value under the interface func (kb *keyBuilder) IsInterfaceNil() bool { return kb == nil diff --git a/trie/keyBuilder/keyBuilder_test.go b/trie/keyBuilder/keyBuilder_test.go index 2dbacece385..0ec6d384ebe 100644 --- a/trie/keyBuilder/keyBuilder_test.go +++ b/trie/keyBuilder/keyBuilder_test.go @@ -11,13 +11,15 @@ func TestKeyBuilder_Clone(t *testing.T) { kb := NewKeyBuilder() kb.BuildKey([]byte("dog")) + assert.Equal(t, uint(3), kb.Size()) - clonedKb := kb.Clone() + clonedKb := kb.ShallowClone() clonedKb.BuildKey([]byte("e")) originalKey, _ := kb.GetKey() modifiedKey, _ := clonedKb.GetKey() assert.NotEqual(t, originalKey, modifiedKey) + assert.Equal(t, uint(4), clonedKb.Size()) } func TestHexToTrieKeyBytes(t *testing.T) { diff --git a/trie/leafNode.go b/trie/leafNode.go index 0b0ab6384d6..5cefe3754ff 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -15,6 +15,7 @@ import ( "github.com/multiversx/mx-chain-core-go/hashing" "github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever/trieNodeData" vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) @@ -561,6 +562,29 @@ func (ln *leafNode) collectLeavesForMigration( return migrationArgs.TrieMigrator.AddLeafToMigrationQueue(leafData, migrationArgs.NewVersion) } +func (ln *leafNode) getNodeData(keyBuilder common.KeyBuilder) ([]common.TrieNodeData, error) { + err := ln.isEmptyOrNil() + if err != nil { + return nil, fmt.Errorf("getNodeData error %w", err) + } + + version, err := ln.getVersion() + if err != nil { + return nil, err + } + + data := make([]common.TrieNodeData, 1) + clonedKeyBuilder := keyBuilder.DeepClone() + clonedKeyBuilder.BuildKey(ln.Key) + nodeData, err := trieNodeData.NewLeafNodeData(clonedKeyBuilder, ln.Value, version) + if err != nil { + return nil, err + } + data[0] = nodeData + + return data, nil +} + // IsInterfaceNil returns true if there is no value 
under the interface func (ln *leafNode) IsInterfaceNil() bool { return ln == nil diff --git a/trie/leafNode_test.go b/trie/leafNode_test.go index e1e47866c8a..297d093ce94 100644 --- a/trie/leafNode_test.go +++ b/trie/leafNode_test.go @@ -13,6 +13,7 @@ import ( "github.com/multiversx/mx-chain-go/testscommon" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + "github.com/multiversx/mx-chain-go/trie/keyBuilder" "github.com/multiversx/mx-chain-go/trie/statistics" "github.com/stretchr/testify/assert" ) @@ -773,3 +774,32 @@ func TestLeafNode_getVersion(t *testing.T) { assert.Nil(t, err) }) } + +func TestLeafNode_getNodeData(t *testing.T) { + t.Parallel() + + t.Run("nil node", func(t *testing.T) { + t.Parallel() + + var ln *leafNode + nodeData, err := ln.getNodeData(keyBuilder.NewDisabledKeyBuilder()) + assert.Nil(t, nodeData) + assert.True(t, errors.Is(err, ErrNilLeafNode)) + }) + t.Run("gets data from child", func(t *testing.T) { + t.Parallel() + + ln := getLn(getTestMarshalizerAndHasher()) + val := []byte("dog") + + kb := keyBuilder.NewKeyBuilder() + nodeData, err := ln.getNodeData(kb) + assert.Nil(t, err) + assert.Equal(t, 1, len(nodeData)) + + assert.Equal(t, uint(3), nodeData[0].GetKeyBuilder().Size()) + assert.Equal(t, val, nodeData[0].GetData()) + assert.Equal(t, uint64(len(val)+len(val)), nodeData[0].Size()) + assert.True(t, nodeData[0].IsLeaf()) + }) +} diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go new file mode 100644 index 00000000000..f428c4d9101 --- /dev/null +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator.go @@ -0,0 +1,157 @@ +package dfsTrieIterator + +import ( + "context" + "encoding/hex" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/trie" + "github.com/multiversx/mx-chain-go/trie/keyBuilder" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever/trieNodeData" +) + +type dfsIterator struct { + nextNodes []common.TrieNodeData + db common.TrieStorageInteractor + marshaller marshal.Marshalizer + hasher hashing.Hasher +} + +// NewIterator creates a new DFS iterator for the trie. 
+func NewIterator(initialState [][]byte, db common.TrieStorageInteractor, marshaller marshal.Marshalizer, hasher hashing.Hasher) (*dfsIterator, error) { + if check.IfNil(db) { + return nil, trie.ErrNilDatabase + } + if check.IfNil(marshaller) { + return nil, trie.ErrNilMarshalizer + } + if check.IfNil(hasher) { + return nil, trie.ErrNilHasher + } + if len(initialState) == 0 { + return nil, trie.ErrEmptyInitialIteratorState + } + + nextNodes, err := getNextNodesFromInitialState(initialState, uint(hasher.Size())) + if err != nil { + return nil, err + } + + return &dfsIterator{ + nextNodes: nextNodes, + db: db, + marshaller: marshaller, + hasher: hasher, + }, nil +} + +func getNextNodesFromInitialState(initialState [][]byte, hashSize uint) ([]common.TrieNodeData, error) { + nextNodes := make([]common.TrieNodeData, len(initialState)) + for i, state := range initialState { + if len(state) < int(hashSize) { + return nil, trie.ErrInvalidIteratorState + } + + nodeHash := state[:hashSize] + key := state[hashSize:] + + kb := keyBuilder.NewKeyBuilder() + kb.BuildKey(key) + nodeData, err := trieNodeData.NewIntermediaryNodeData(kb, nodeHash) + if err != nil { + return nil, err + } + nextNodes[i] = nodeData + } + + return nextNodes, nil +} + +func getIteratorStateFromNextNodes(nextNodes []common.TrieNodeData) [][]byte { + iteratorState := make([][]byte, len(nextNodes)) + for i, node := range nextNodes { + nodeHash := node.GetData() + key := node.GetKeyBuilder().GetRawKey() + + iteratorState[i] = append(nodeHash, key...) + } + + return iteratorState +} + +// GetLeaves retrieves leaves from the trie. It stops either when the number of leaves is reached or the context is done. +func (it *dfsIterator) GetLeaves(numLeaves int, maxSize uint64, leavesParser common.TrieLeafParser, ctx context.Context) (map[string]string, error) { + retrievedLeaves := make(map[string]string) + leavesSize := uint64(0) + for { + nextNodes := make([]common.TrieNodeData, 0) + if leavesSize >= maxSize { + return retrievedLeaves, nil + } + + if len(retrievedLeaves) >= numLeaves && numLeaves != 0 { + return retrievedLeaves, nil + } + + if it.FinishedIteration() { + return retrievedLeaves, nil + } + + if common.IsContextDone(ctx) { + return retrievedLeaves, nil + } + + nextNode := it.nextNodes[0] + nodeHash := nextNode.GetData() + childrenNodes, err := trie.GetNodeDataFromHash(nodeHash, nextNode.GetKeyBuilder(), it.db, it.marshaller, it.hasher) + if err != nil { + return nil, err + } + + for _, childNode := range childrenNodes { + if childNode.IsLeaf() { + key, err := childNode.GetKeyBuilder().GetKey() + if err != nil { + return nil, err + } + + keyValHolder, err := leavesParser.ParseLeaf(key, childNode.GetData(), childNode.GetVersion()) + if err != nil { + return nil, err + } + + hexKey := hex.EncodeToString(keyValHolder.Key()) + hexData := hex.EncodeToString(keyValHolder.Value()) + retrievedLeaves[hexKey] = hexData + leavesSize += uint64(len(hexKey) + len(hexData)) + continue + } + + nextNodes = append(nextNodes, childNode) + } + + it.nextNodes = append(nextNodes, it.nextNodes[1:]...) + } +} + +// GetIteratorState returns the state of the iterator from which it can be resumed by another call. +func (it *dfsIterator) GetIteratorState() [][]byte { + if it.FinishedIteration() { + return nil + } + + return getIteratorStateFromNextNodes(it.nextNodes) +} + +// FinishedIteration checks if the iterator has finished the iteration. 
+func (it *dfsIterator) FinishedIteration() bool { + return len(it.nextNodes) == 0 +} + +// IsInterfaceNil returns true if there is no value under the interface +func (it *dfsIterator) IsInterfaceNil() bool { + return it == nil +} diff --git a/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go new file mode 100644 index 00000000000..3857ddcd9e1 --- /dev/null +++ b/trie/leavesRetriever/dfsTrieIterator/dfsTrieIterator_test.go @@ -0,0 +1,270 @@ +package dfsTrieIterator + +import ( + "bytes" + "context" + "encoding/hex" + "math" + "testing" + + "github.com/multiversx/mx-chain-go/state/parsers" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + trieTest "github.com/multiversx/mx-chain-go/testscommon/state" + "github.com/multiversx/mx-chain-go/testscommon/storageManager" + "github.com/multiversx/mx-chain-go/trie" + "github.com/stretchr/testify/assert" +) + +var maxSize = uint64(math.MaxUint64) + +func TestNewIterator(t *testing.T) { + t.Parallel() + + t.Run("nil db", func(t *testing.T) { + t.Parallel() + + iterator, err := NewIterator([][]byte{[]byte("initial"), []byte("state")}, nil, &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) + assert.Nil(t, iterator) + assert.Equal(t, trie.ErrNilDatabase, err) + }) + t.Run("nil marshaller", func(t *testing.T) { + t.Parallel() + + iterator, err := NewIterator([][]byte{[]byte("initial"), []byte("state")}, testscommon.NewMemDbMock(), nil, &hashingMocks.HasherMock{}) + assert.Nil(t, iterator) + assert.Equal(t, trie.ErrNilMarshalizer, err) + }) + t.Run("nil hasher", func(t *testing.T) { + t.Parallel() + + iterator, err := NewIterator([][]byte{[]byte("initial"), []byte("state")}, testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, nil) + assert.Nil(t, iterator) + assert.Equal(t, trie.ErrNilHasher, err) + }) + t.Run("empty initial state", func(t *testing.T) { + t.Parallel() + + iterator, err := NewIterator([][]byte{}, testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) + assert.Nil(t, iterator) + assert.Equal(t, trie.ErrEmptyInitialIteratorState, err) + }) + t.Run("invalid initial state", func(t *testing.T) { + t.Parallel() + + iterator, err := NewIterator([][]byte{[]byte("invalid state")}, testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}) + assert.Nil(t, iterator) + assert.Equal(t, trie.ErrInvalidIteratorState, err) + }) + t.Run("should work", func(t *testing.T) { + t.Parallel() + + initialState := [][]byte{ + bytes.Repeat([]byte{0}, 40), + bytes.Repeat([]byte{1}, 40), + } + + db, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, err := NewIterator(initialState, db, marshaller, hasher) + assert.Nil(t, err) + + assert.Equal(t, 2, len(iterator.nextNodes)) + }) +} + +func TestDfsIterator_GetLeaves(t *testing.T) { + t.Parallel() + + t.Run("context done returns retrieved leaves and saves iterator state", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + expectedNumLeaves := 9 + numGetCalls := 0 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + ctx, cancel := context.WithCancel(context.Background()) + + trieStorage := tr.GetStorageManager() + 
dbWrapper := &storageManager.StorageManagerStub{ + GetCalled: func(key []byte) ([]byte, error) { + if numGetCalls == 15 { + cancel() + } + numGetCalls++ + return trieStorage.Get(key) + }, + PutCalled: func(key, data []byte) error { + return trieStorage.Put(key, data) + }, + } + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, dbWrapper, marshaller, hasher) + + trieData, err := iterator.GetLeaves(numLeaves, maxSize, parsers.NewMainTrieLeafParser(), ctx) + assert.Nil(t, err) + assert.Equal(t, expectedNumLeaves, len(trieData)) + }) + t.Run("finishes iteration returns retrieved leaves", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + trieData, err := iterator.GetLeaves(numLeaves, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) + assert.Nil(t, err) + assert.Equal(t, numLeaves, len(trieData)) + }) + t.Run("num leaves reached returns retrieved leaves and saves iterator context", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + expectedNumRetrievedLeaves := 17 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + trieData, err := iterator.GetLeaves(17, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) + assert.Nil(t, err) + assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) + }) + t.Run("num leaves 0 iterates until maxSize reached", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + trieData, err := iterator.GetLeaves(0, 200, parsers.NewMainTrieLeafParser(), context.Background()) + assert.Nil(t, err) + assert.Equal(t, 8, len(trieData)) + assert.Equal(t, 8, len(iterator.nextNodes)) + }) + t.Run("max size reached returns retrieved leaves and saves iterator context", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + iteratorMaxSize := uint64(200) + trieData, err := iterator.GetLeaves(numLeaves, iteratorMaxSize, parsers.NewMainTrieLeafParser(), context.Background()) + assert.Nil(t, err) + assert.Equal(t, 8, len(trieData)) + assert.Equal(t, 8, len(iterator.nextNodes)) + }) + t.Run("retrieve all leaves in multiple calls", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + numRetrievedLeaves := 0 + numIterations := 0 + for numRetrievedLeaves < numLeaves { + trieData, err := iterator.GetLeaves(5, maxSize, parsers.NewMainTrieLeafParser(), 
context.Background()) + assert.Nil(t, err) + + numRetrievedLeaves += len(trieData) + numIterations++ + } + + assert.Equal(t, numLeaves, numRetrievedLeaves) + assert.Equal(t, 5, numIterations) + }) + t.Run("retrieve leaves with nil context does not panic", func(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + expectedNumRetrievedLeaves := 0 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + trieData, err := iterator.GetLeaves(numLeaves, maxSize, parsers.NewMainTrieLeafParser(), nil) + assert.Nil(t, err) + assert.Equal(t, expectedNumRetrievedLeaves, len(trieData)) + }) +} + +func TestDfsIterator_GetIteratorState(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + _ = tr.Update([]byte("doe"), []byte("reindeer")) + _ = tr.Update([]byte("dog"), []byte("puppy")) + _ = tr.Update([]byte("ddog"), []byte("cat")) + _ = tr.Commit() + rootHash, _ := tr.RootHash() + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + leaves, err := iterator.GetLeaves(2, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) + assert.Nil(t, err) + assert.Equal(t, 2, len(leaves)) + val, ok := leaves[hex.EncodeToString([]byte("doe"))] + assert.True(t, ok) + assert.Equal(t, hex.EncodeToString([]byte("reindeer")), val) + val, ok = leaves[hex.EncodeToString([]byte("ddog"))] + assert.True(t, ok) + assert.Equal(t, hex.EncodeToString([]byte("cat")), val) + + iteratorState := iterator.GetIteratorState() + assert.Equal(t, 1, len(iteratorState)) + hash := iteratorState[0][:hasher.Size()] + key := iteratorState[0][hasher.Size():] + assert.Equal(t, []byte{0x7, 0x6, 0xf, 0x6, 0x4, 0x6, 0x10}, key) + leafBytes, err := tr.GetStorageManager().Get(hash) + assert.Nil(t, err) + assert.NotNil(t, leafBytes) +} + +func TestDfsIterator_FinishedIteration(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + numLeaves := 25 + trieTest.AddDataToTrie(tr, numLeaves) + rootHash, _ := tr.RootHash() + _, marshaller, hasher := trieTest.GetDefaultTrieParameters() + iterator, _ := NewIterator([][]byte{rootHash}, tr.GetStorageManager(), marshaller, hasher) + + numRetrievedLeaves := 0 + for numRetrievedLeaves < numLeaves { + assert.False(t, iterator.FinishedIteration()) + trieData, err := iterator.GetLeaves(5, maxSize, parsers.NewMainTrieLeafParser(), context.Background()) + assert.Nil(t, err) + + numRetrievedLeaves += len(trieData) + } + + assert.Equal(t, numLeaves, numRetrievedLeaves) + assert.True(t, iterator.FinishedIteration()) +} diff --git a/trie/leavesRetriever/disabledLeavesRetriever.go b/trie/leavesRetriever/disabledLeavesRetriever.go new file mode 100644 index 00000000000..8d3d33720ba --- /dev/null +++ b/trie/leavesRetriever/disabledLeavesRetriever.go @@ -0,0 +1,24 @@ +package leavesRetriever + +import ( + "context" + + "github.com/multiversx/mx-chain-go/common" +) + +type disabledLeavesRetriever struct{} + +// NewDisabledLeavesRetriever creates a new disabled leaves retriever +func NewDisabledLeavesRetriever() *disabledLeavesRetriever { + return &disabledLeavesRetriever{} +} + +// GetLeaves returns an empty map and a nil byte slice for this implementation +func (dlr *disabledLeavesRetriever) GetLeaves(_ int, _ [][]byte, _ common.TrieLeafParser, _ context.Context) 
(map[string]string, [][]byte, error) { + return make(map[string]string), [][]byte{}, nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (dlr *disabledLeavesRetriever) IsInterfaceNil() bool { + return dlr == nil +} diff --git a/trie/leavesRetriever/errors.go b/trie/leavesRetriever/errors.go new file mode 100644 index 00000000000..e1f8aa6c11a --- /dev/null +++ b/trie/leavesRetriever/errors.go @@ -0,0 +1,15 @@ +package leavesRetriever + +import "errors" + +// ErrNilDB is returned when the given db is nil +var ErrNilDB = errors.New("nil db") + +// ErrNilMarshaller is returned when the given marshaller is nil +var ErrNilMarshaller = errors.New("nil marshaller") + +// ErrNilHasher is returned when the given hasher is nil +var ErrNilHasher = errors.New("nil hasher") + +// ErrIteratorNotFound is returned when the iterator is not found +var ErrIteratorNotFound = errors.New("iterator not found") diff --git a/trie/leavesRetriever/leavesRetriever.go b/trie/leavesRetriever/leavesRetriever.go new file mode 100644 index 00000000000..a2822975f43 --- /dev/null +++ b/trie/leavesRetriever/leavesRetriever.go @@ -0,0 +1,64 @@ +package leavesRetriever + +import ( + "context" + "fmt" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/hashing" + "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever/dfsTrieIterator" +) + +type leavesRetriever struct { + db common.TrieStorageInteractor + marshaller marshal.Marshalizer + hasher hashing.Hasher + maxSize uint64 +} + +// NewLeavesRetriever creates a new leaves retriever +func NewLeavesRetriever(db common.TrieStorageInteractor, marshaller marshal.Marshalizer, hasher hashing.Hasher, maxSize uint64) (*leavesRetriever, error) { + if check.IfNil(db) { + return nil, ErrNilDB + } + if check.IfNil(marshaller) { + return nil, ErrNilMarshaller + } + if check.IfNil(hasher) { + return nil, ErrNilHasher + } + + return &leavesRetriever{ + db: db, + marshaller: marshaller, + hasher: hasher, + maxSize: maxSize, + }, nil +} + +// GetLeaves retrieves leaves from the trie starting from the iterator state. It will also return the new iterator state +// from which one can continue the iteration. 
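+// The returned iterator state acts as a resume point: passing it back on a subsequent call continues
+// the retrieval from where the previous call stopped. As exercised by the dfsTrieIterator tests above,
+// retrieval stops when numLeaves is reached, when the configured maxSize is exceeded, or when ctx is done.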
+func (lr *leavesRetriever) GetLeaves(numLeaves int, iteratorState [][]byte, leavesParser common.TrieLeafParser, ctx context.Context) (map[string]string, [][]byte, error) { + if check.IfNil(leavesParser) { + return nil, nil, fmt.Errorf("nil leaves parser") + } + + iterator, err := dfsTrieIterator.NewIterator(iteratorState, lr.db, lr.marshaller, lr.hasher) + if err != nil { + return nil, nil, err + } + + leavesData, err := iterator.GetLeaves(numLeaves, lr.maxSize, leavesParser, ctx) + if err != nil { + return nil, nil, err + } + + return leavesData, iterator.GetIteratorState(), nil +} + +// IsInterfaceNil returns true if there is no value under the interface +func (lr *leavesRetriever) IsInterfaceNil() bool { + return lr == nil +} diff --git a/trie/leavesRetriever/leavesRetriever_test.go b/trie/leavesRetriever/leavesRetriever_test.go new file mode 100644 index 00000000000..8fd376de439 --- /dev/null +++ b/trie/leavesRetriever/leavesRetriever_test.go @@ -0,0 +1,78 @@ +package leavesRetriever_test + +import ( + "context" + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-core-go/core/keyValStorage" + "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" + "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" + trieTest "github.com/multiversx/mx-chain-go/testscommon/state" + trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" + "github.com/multiversx/mx-chain-go/trie/leavesRetriever" + "github.com/stretchr/testify/assert" +) + +func TestNewLeavesRetriever(t *testing.T) { + t.Parallel() + + t.Run("nil db", func(t *testing.T) { + t.Parallel() + + lr, err := leavesRetriever.NewLeavesRetriever(nil, &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100) + assert.Nil(t, lr) + assert.Equal(t, leavesRetriever.ErrNilDB, err) + }) + t.Run("nil marshaller", func(t *testing.T) { + t.Parallel() + + lr, err := leavesRetriever.NewLeavesRetriever(testscommon.NewMemDbMock(), nil, &hashingMocks.HasherMock{}, 100) + assert.Nil(t, lr) + assert.Equal(t, leavesRetriever.ErrNilMarshaller, err) + }) + t.Run("nil hasher", func(t *testing.T) { + t.Parallel() + + lr, err := leavesRetriever.NewLeavesRetriever(testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, nil, 100) + assert.Nil(t, lr) + assert.Equal(t, leavesRetriever.ErrNilHasher, err) + }) + t.Run("new leaves retriever", func(t *testing.T) { + t.Parallel() + + var lr common.TrieLeavesRetriever + assert.True(t, check.IfNil(lr)) + + lr, err := leavesRetriever.NewLeavesRetriever(testscommon.NewMemDbMock(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100) + assert.Nil(t, err) + assert.False(t, check.IfNil(lr)) + }) +} + +func TestLeavesRetriever_GetLeaves(t *testing.T) { + t.Parallel() + + tr := trieTest.GetNewTrie() + trieTest.AddDataToTrie(tr, 25) + rootHash, _ := tr.RootHash() + leafParser := &trieMock.TrieLeafParserStub{ + ParseLeafCalled: func(key []byte, val []byte, version core.TrieNodeVersion) (core.KeyValueHolder, error) { + return keyValStorage.NewKeyValStorage(key, val), nil + }, + } + lr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 
100000) + leaves, newIteratorState, err := lr.GetLeaves(10, [][]byte{rootHash}, leafParser, context.Background()) + assert.Nil(t, err) + assert.Equal(t, 10, len(leaves)) + assert.Equal(t, 8, len(newIteratorState)) + + newLr, _ := leavesRetriever.NewLeavesRetriever(tr.GetStorageManager(), &marshallerMock.MarshalizerMock{}, &hashingMocks.HasherMock{}, 100000) + leaves, newIteratorState, err = newLr.GetLeaves(10, newIteratorState, leafParser, context.Background()) + assert.Nil(t, err) + assert.Equal(t, 10, len(leaves)) + assert.Equal(t, 3, len(newIteratorState)) +} diff --git a/trie/leavesRetriever/trieNodeData/baseNodeData.go b/trie/leavesRetriever/trieNodeData/baseNodeData.go new file mode 100644 index 00000000000..d4c11f025d0 --- /dev/null +++ b/trie/leavesRetriever/trieNodeData/baseNodeData.go @@ -0,0 +1,31 @@ +package trieNodeData + +import ( + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" +) + +type baseNodeData struct { + keyBuilder common.KeyBuilder + data []byte +} + +// GetData returns the bytes stored in the data field +func (bnd *baseNodeData) GetData() []byte { + return bnd.data +} + +// GetKeyBuilder returns the keyBuilder +func (bnd *baseNodeData) GetKeyBuilder() common.KeyBuilder { + return bnd.keyBuilder +} + +// Size returns the size of the data field combined with the keyBuilder size +func (bnd *baseNodeData) Size() uint64 { + keyBuilderSize := uint(0) + if !check.IfNil(bnd.keyBuilder) { + keyBuilderSize = bnd.keyBuilder.Size() + } + + return uint64(len(bnd.data)) + uint64(keyBuilderSize) +} diff --git a/trie/leavesRetriever/trieNodeData/baseNodeData_test.go b/trie/leavesRetriever/trieNodeData/baseNodeData_test.go new file mode 100644 index 00000000000..7935c95d95d --- /dev/null +++ b/trie/leavesRetriever/trieNodeData/baseNodeData_test.go @@ -0,0 +1,37 @@ +package trieNodeData + +import ( + "testing" + + "github.com/multiversx/mx-chain-go/trie/keyBuilder" + "github.com/stretchr/testify/assert" +) + +func TestBaseNodeData(t *testing.T) { + t.Parallel() + + t.Run("empty base node data", func(t *testing.T) { + t.Parallel() + + bnd := &baseNodeData{} + assert.Nil(t, bnd.GetData()) + assert.Nil(t, bnd.GetKeyBuilder()) + assert.Equal(t, uint64(0), bnd.Size()) + }) + t.Run("base node data with data", func(t *testing.T) { + t.Parallel() + + data := []byte("data") + key := []byte("key") + kb := keyBuilder.NewKeyBuilder() + kb.BuildKey(key) + bnd := &baseNodeData{ + data: data, + keyBuilder: kb, + } + + assert.Equal(t, data, bnd.GetData()) + assert.Equal(t, kb, bnd.GetKeyBuilder()) + assert.Equal(t, uint64(len(data)+len(key)), bnd.Size()) + }) +} diff --git a/trie/leavesRetriever/trieNodeData/errors.go b/trie/leavesRetriever/trieNodeData/errors.go new file mode 100644 index 00000000000..7d04c81e598 --- /dev/null +++ b/trie/leavesRetriever/trieNodeData/errors.go @@ -0,0 +1,6 @@ +package trieNodeData + +import "errors" + +// ErrNilKeyBuilder is returned when the given key builder is nil +var ErrNilKeyBuilder = errors.New("nil key builder") diff --git a/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go b/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go new file mode 100644 index 00000000000..10a18a856fa --- /dev/null +++ b/trie/leavesRetriever/trieNodeData/intermediaryNodeData.go @@ -0,0 +1,40 @@ +package trieNodeData + +import ( + "github.com/multiversx/mx-chain-core-go/core" + 
"github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" +) + +type intermediaryNodeData struct { + *baseNodeData +} + +// NewIntermediaryNodeData creates a new intermediary node data +func NewIntermediaryNodeData(key common.KeyBuilder, data []byte) (*intermediaryNodeData, error) { + if check.IfNil(key) { + return nil, ErrNilKeyBuilder + } + + return &intermediaryNodeData{ + baseNodeData: &baseNodeData{ + keyBuilder: key, + data: data, + }, + }, nil +} + +// IsLeaf returns false +func (ind *intermediaryNodeData) IsLeaf() bool { + return false +} + +// GetVersion returns NotSpecified +func (ind *intermediaryNodeData) GetVersion() core.TrieNodeVersion { + return core.NotSpecified +} + +// IsInterfaceNil returns true if there is no value under the interface +func (ind *intermediaryNodeData) IsInterfaceNil() bool { + return ind == nil +} diff --git a/trie/leavesRetriever/trieNodeData/intermediaryNodeData_test.go b/trie/leavesRetriever/trieNodeData/intermediaryNodeData_test.go new file mode 100644 index 00000000000..2511469c102 --- /dev/null +++ b/trie/leavesRetriever/trieNodeData/intermediaryNodeData_test.go @@ -0,0 +1,31 @@ +package trieNodeData + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/trie/keyBuilder" + "github.com/stretchr/testify/assert" +) + +func TestNewIntermediaryNodeData(t *testing.T) { + t.Parallel() + + var ind *intermediaryNodeData + assert.True(t, check.IfNil(ind)) + + ind, err := NewIntermediaryNodeData(nil, nil) + assert.Equal(t, ErrNilKeyBuilder, err) + assert.True(t, check.IfNil(ind)) + + ind, err = NewIntermediaryNodeData(keyBuilder.NewKeyBuilder(), []byte("data")) + assert.Nil(t, err) + assert.False(t, check.IfNil(ind)) +} + +func TestIntermediaryNodeData(t *testing.T) { + t.Parallel() + + ind, _ := NewIntermediaryNodeData(keyBuilder.NewKeyBuilder(), []byte("data")) + assert.False(t, ind.IsLeaf()) +} diff --git a/trie/leavesRetriever/trieNodeData/leafNodeData.go b/trie/leavesRetriever/trieNodeData/leafNodeData.go new file mode 100644 index 00000000000..08a80f2d3f8 --- /dev/null +++ b/trie/leavesRetriever/trieNodeData/leafNodeData.go @@ -0,0 +1,42 @@ +package trieNodeData + +import ( + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/common" +) + +type leafNodeData struct { + *baseNodeData + version core.TrieNodeVersion +} + +// NewLeafNodeData creates a new leaf node data +func NewLeafNodeData(key common.KeyBuilder, data []byte, version core.TrieNodeVersion) (*leafNodeData, error) { + if check.IfNil(key) { + return nil, ErrNilKeyBuilder + } + + return &leafNodeData{ + baseNodeData: &baseNodeData{ + keyBuilder: key, + data: data, + }, + version: version, + }, nil +} + +// IsLeaf returns true +func (lnd *leafNodeData) IsLeaf() bool { + return true +} + +// GetVersion returns the version of the leaf +func (lnd *leafNodeData) GetVersion() core.TrieNodeVersion { + return lnd.version +} + +// IsInterfaceNil returns true if there is no value under the interface +func (lnd *leafNodeData) IsInterfaceNil() bool { + return lnd == nil +} diff --git a/trie/leavesRetriever/trieNodeData/leafNodeData_test.go b/trie/leavesRetriever/trieNodeData/leafNodeData_test.go new file mode 100644 index 00000000000..dc4b4ab656b --- /dev/null +++ 
b/trie/leavesRetriever/trieNodeData/leafNodeData_test.go @@ -0,0 +1,32 @@ +package trieNodeData + +import ( + "testing" + + "github.com/multiversx/mx-chain-core-go/core" + "github.com/multiversx/mx-chain-core-go/core/check" + "github.com/multiversx/mx-chain-go/trie/keyBuilder" + "github.com/stretchr/testify/assert" +) + +func TestNewLeafNodeData(t *testing.T) { + t.Parallel() + + var lnd *leafNodeData + assert.True(t, check.IfNil(lnd)) + + lnd, err := NewLeafNodeData(nil, nil, core.NotSpecified) + assert.Equal(t, ErrNilKeyBuilder, err) + assert.True(t, check.IfNil(lnd)) + + lnd, err = NewLeafNodeData(keyBuilder.NewKeyBuilder(), []byte("data"), core.NotSpecified) + assert.Nil(t, err) + assert.False(t, check.IfNil(lnd)) +} + +func TestLeafNodeData(t *testing.T) { + t.Parallel() + + lnd, _ := NewLeafNodeData(keyBuilder.NewKeyBuilder(), []byte("data"), core.NotSpecified) + assert.True(t, lnd.IsLeaf()) +} diff --git a/trie/mock/keyBuilderStub.go b/trie/mock/keyBuilderStub.go index 8ba29de2213..7fec902d542 100644 --- a/trie/mock/keyBuilderStub.go +++ b/trie/mock/keyBuilderStub.go @@ -4,9 +4,12 @@ import "github.com/multiversx/mx-chain-go/common" // KeyBuilderStub - type KeyBuilderStub struct { - BuildKeyCalled func(keyPart []byte) - GetKeyCalled func() ([]byte, error) - CloneCalled func() common.KeyBuilder + BuildKeyCalled func(keyPart []byte) + GetKeyCalled func() ([]byte, error) + GetRawKeyCalled func() []byte + ShallowCloneCalled func() common.KeyBuilder + DeepCloneCalled func() common.KeyBuilder + SizeCalled func() uint } // BuildKey - @@ -25,15 +28,42 @@ func (stub *KeyBuilderStub) GetKey() ([]byte, error) { return []byte{}, nil } -// Clone - -func (stub *KeyBuilderStub) Clone() common.KeyBuilder { - if stub.CloneCalled != nil { - return stub.CloneCalled() +// GetRawKey - +func (stub *KeyBuilderStub) GetRawKey() []byte { + if stub.GetRawKeyCalled != nil { + return stub.GetRawKeyCalled() + } + + return []byte{} +} + +// ShallowClone - +func (stub *KeyBuilderStub) ShallowClone() common.KeyBuilder { + if stub.ShallowCloneCalled != nil { + return stub.ShallowCloneCalled() } return &KeyBuilderStub{} } +// DeepClone - +func (stub *KeyBuilderStub) DeepClone() common.KeyBuilder { + if stub.DeepCloneCalled != nil { + return stub.DeepCloneCalled() + } + + return &KeyBuilderStub{} +} + +// Size - +func (stub *KeyBuilderStub) Size() uint { + if stub.SizeCalled != nil { + return stub.SizeCalled() + } + + return 0 +} + // IsInterfaceNil - func (stub *KeyBuilderStub) IsInterfaceNil() bool { return stub == nil diff --git a/trie/node.go b/trie/node.go index 754b3b3548d..63f7a69c1e1 100644 --- a/trie/node.go +++ b/trie/node.go @@ -152,7 +152,7 @@ func resolveIfCollapsed(n node, pos byte, db common.TrieStorageInteractor) error func handleStorageInteractorStats(db common.TrieStorageInteractor) { if db != nil { - db.GetStateStatsHandler().IncrementTrie() + db.GetStateStatsHandler().IncrTrie() } } @@ -180,7 +180,7 @@ func hasValidHash(n node) (bool, error) { } func decodeNode(encNode []byte, marshalizer marshal.Marshalizer, hasher hashing.Hasher) (node, error) { - if encNode == nil || len(encNode) < 1 { + if len(encNode) < 1 { return nil, ErrInvalidEncoding } diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index da9eb87a65f..ed92942eabe 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -725,6 +725,16 @@ func (tr *patriciaMerkleTrie) IsMigratedToLatestVersion() (bool, error) { return 
version == versionForNewlyAddedData, nil
}

+// GetNodeDataFromHash returns the node data for the given hash
+func GetNodeDataFromHash(hash []byte, keyBuilder common.KeyBuilder, db common.TrieStorageInteractor, msh marshal.Marshalizer, hsh hashing.Hasher) ([]common.TrieNodeData, error) {
+	n, err := getNodeFromDBAndDecode(hash, db, msh, hsh)
+	if err != nil {
+		return nil, err
+	}
+
+	return n.getNodeData(keyBuilder)
+}
+
 // Close stops all the active goroutines started by the trie
 func (tr *patriciaMerkleTrie) Close() error {
 	tr.mutOperation.Lock()
diff --git a/trie/patriciaMerkleTrie_test.go b/trie/patriciaMerkleTrie_test.go
index 8a02a8edcd9..76f4e34e230 100644
--- a/trie/patriciaMerkleTrie_test.go
+++ b/trie/patriciaMerkleTrie_test.go
@@ -26,6 +26,7 @@ import (
 	"github.com/multiversx/mx-chain-go/common/holders"
 	errorsCommon "github.com/multiversx/mx-chain-go/errors"
 	"github.com/multiversx/mx-chain-go/state/parsers"
+	"github.com/multiversx/mx-chain-go/testscommon"
 	"github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock"
 	"github.com/multiversx/mx-chain-go/testscommon/storageManager"
 	trieMock "github.com/multiversx/mx-chain-go/testscommon/trie"
@@ -626,7 +627,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) {
 		keyBuilderStub.GetKeyCalled = func() ([]byte, error) {
 			return nil, expectedErr
 		}
-		keyBuilderStub.CloneCalled = func() common.KeyBuilder {
+		keyBuilderStub.ShallowCloneCalled = func() common.KeyBuilder {
 			return keyBuilderStub
 		}
@@ -668,7 +669,7 @@ func TestPatriciaMerkleTrie_GetAllLeavesOnChannel(t *testing.T) {
 			}
 			return nil, expectedErr
 		}
-		keyBuilderStub.CloneCalled = func() common.KeyBuilder {
+		keyBuilderStub.ShallowCloneCalled = func() common.KeyBuilder {
 			return keyBuilderStub
 		}
@@ -1515,6 +1516,39 @@ func TestPatriciaMerkleTrie_IsMigrated(t *testing.T) {
 	})
 }
+func TestGetNodeDataFromHash(t *testing.T) {
+	t.Parallel()
+
+	tr := initTrie()
+	_ = tr.Update([]byte("111"), []byte("111"))
+	_ = tr.Update([]byte("aaa"), []byte("aaa"))
+	_ = tr.Commit()
+
+	hashSize := 32
+	keySize := 1
+
+	rootHash, _ := tr.RootHash()
+	nodeData, err := trie.GetNodeDataFromHash(rootHash, keyBuilder.NewKeyBuilder(), tr.GetStorageManager(), &marshal.GogoProtoMarshalizer{}, &testscommon.KeccakMock{})
+	assert.Nil(t, err)
+	assert.Equal(t, 3, len(nodeData))
+
+	firstChildData := nodeData[0]
+	assert.Equal(t, uint(1), firstChildData.GetKeyBuilder().Size())
+	assert.Equal(t, uint64(hashSize+keySize), firstChildData.Size())
+	assert.False(t, firstChildData.IsLeaf())
+
+	secondChildData := nodeData[1]
+	assert.Equal(t, uint(1), secondChildData.GetKeyBuilder().Size())
+	assert.Equal(t, uint64(hashSize+keySize), secondChildData.Size())
+	assert.False(t, secondChildData.IsLeaf())
+
+	thirdChildData := nodeData[2]
+	assert.Equal(t, uint(1), thirdChildData.GetKeyBuilder().Size())
+	assert.Equal(t, uint64(hashSize+keySize), thirdChildData.Size())
+	assert.False(t, thirdChildData.IsLeaf())
+
+}
+
 func BenchmarkPatriciaMerkleTree_Insert(b *testing.B) {
 	tr := emptyTrie()
 	hsh := keccak.NewKeccak()
diff --git a/trie/sync_test.go b/trie/sync_test.go
index ab5083eb85a..7d6c26b3ba5 100644
--- a/trie/sync_test.go
+++ b/trie/sync_test.go
@@ -10,14 +10,16 @@ import (
 	"github.com/multiversx/mx-chain-core-go/core"
 	"github.com/multiversx/mx-chain-core-go/core/check"
 	"github.com/multiversx/mx-chain-core-go/data"
+
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" "github.com/multiversx/mx-chain-go/testscommon/hashingMocks" "github.com/multiversx/mx-chain-go/testscommon/marshallerMock" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/trie/statistics" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func createMockArgument(timeout time.Duration) ArgTrieSyncer { @@ -32,7 +34,7 @@ func createMockArgument(timeout time.Duration) ArgTrieSyncer { return ArgTrieSyncer{ RequestHandler: &testscommon.RequestHandlerStub{}, - InterceptedNodes: testscommon.NewCacherMock(), + InterceptedNodes: cache.NewCacherMock(), DB: trieStorage, Hasher: &hashingMocks.HasherMock{}, Marshalizer: &marshallerMock.MarshalizerMock{}, diff --git a/update/factory/exportHandlerFactory.go b/update/factory/exportHandlerFactory.go index c13f25f3f5a..0cda7a5d2e0 100644 --- a/update/factory/exportHandlerFactory.go +++ b/update/factory/exportHandlerFactory.go @@ -8,6 +8,8 @@ import ( "time" "github.com/multiversx/mx-chain-core-go/core/check" + logger "github.com/multiversx/mx-chain-logger-go" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/dataRetriever" @@ -30,7 +32,6 @@ import ( "github.com/multiversx/mx-chain-go/update/genesis" "github.com/multiversx/mx-chain-go/update/storing" "github.com/multiversx/mx-chain-go/update/sync" - logger "github.com/multiversx/mx-chain-logger-go" ) var log = logger.GetOrCreate("update/factory") @@ -69,6 +70,7 @@ type ArgsExporter struct { TrieSyncerVersion int CheckNodesOnDisk bool NodeOperationMode common.NodeOperation + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } type exportHandlerFactory struct { @@ -108,6 +110,7 @@ type exportHandlerFactory struct { trieSyncerVersion int checkNodesOnDisk bool nodeOperationMode common.NodeOperation + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // NewExportHandlerFactory creates an exporter factory @@ -266,6 +269,7 @@ func NewExportHandlerFactory(args ArgsExporter) (*exportHandlerFactory, error) { checkNodesOnDisk: args.CheckNodesOnDisk, statusCoreComponents: args.StatusCoreComponents, nodeOperationMode: args.NodeOperationMode, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, } return e, nil @@ -588,6 +592,7 @@ func (e *exportHandlerFactory) createInterceptors() error { FullArchiveInterceptorsContainer: e.fullArchiveInterceptorsContainer, AntifloodHandler: e.networkComponents.InputAntiFloodHandler(), NodeOperationMode: e.nodeOperationMode, + InterceptedDataVerifierFactory: e.interceptedDataVerifierFactory, } fullSyncInterceptors, err := NewFullSyncInterceptorsContainerFactory(argsInterceptors) if err != nil { diff --git a/update/factory/fullSyncInterceptors.go b/update/factory/fullSyncInterceptors.go index 0fe0298c4d6..c7d005e94fb 100644 --- a/update/factory/fullSyncInterceptors.go +++ b/update/factory/fullSyncInterceptors.go @@ -7,6 +7,7 
@@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/core/throttler" "github.com/multiversx/mx-chain-core-go/marshal" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" @@ -29,25 +30,26 @@ const numGoRoutines = 2000 // fullSyncInterceptorsContainerFactory will handle the creation the interceptors container for shards type fullSyncInterceptorsContainerFactory struct { - mainContainer process.InterceptorsContainer - fullArchiveContainer process.InterceptorsContainer - shardCoordinator sharding.Coordinator - accounts state.AccountsAdapter - store dataRetriever.StorageService - dataPool dataRetriever.PoolsHolder - mainMessenger process.TopicHandler - fullArchiveMessenger process.TopicHandler - nodesCoordinator nodesCoordinator.NodesCoordinator - blockBlackList process.TimeCacher - argInterceptorFactory *interceptorFactory.ArgInterceptedDataFactory - globalThrottler process.InterceptorThrottler - maxTxNonceDeltaAllowed int - addressPubkeyConv core.PubkeyConverter - whiteListHandler update.WhiteListHandler - whiteListerVerifiedTxs update.WhiteListHandler - antifloodHandler process.P2PAntifloodHandler - preferredPeersHolder update.PreferredPeersHolderHandler - nodeOperationMode common.NodeOperation + mainContainer process.InterceptorsContainer + fullArchiveContainer process.InterceptorsContainer + shardCoordinator sharding.Coordinator + accounts state.AccountsAdapter + store dataRetriever.StorageService + dataPool dataRetriever.PoolsHolder + mainMessenger process.TopicHandler + fullArchiveMessenger process.TopicHandler + nodesCoordinator nodesCoordinator.NodesCoordinator + blockBlackList process.TimeCacher + argInterceptorFactory *interceptorFactory.ArgInterceptedDataFactory + globalThrottler process.InterceptorThrottler + maxTxNonceDeltaAllowed int + addressPubkeyConv core.PubkeyConverter + whiteListHandler update.WhiteListHandler + whiteListerVerifiedTxs update.WhiteListHandler + antifloodHandler process.P2PAntifloodHandler + preferredPeersHolder update.PreferredPeersHolderHandler + nodeOperationMode common.NodeOperation + interceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // ArgsNewFullSyncInterceptorsContainerFactory holds the arguments needed for fullSyncInterceptorsContainerFactory @@ -75,6 +77,7 @@ type ArgsNewFullSyncInterceptorsContainerFactory struct { FullArchiveInterceptorsContainer process.InterceptorsContainer AntifloodHandler process.P2PAntifloodHandler NodeOperationMode common.NodeOperation + InterceptedDataVerifierFactory process.InterceptedDataVerifierFactory } // NewFullSyncInterceptorsContainerFactory is responsible for creating a new interceptors factory object @@ -132,6 +135,9 @@ func NewFullSyncInterceptorsContainerFactory( if check.IfNil(args.AntifloodHandler) { return nil, process.ErrNilAntifloodHandler } + if check.IfNil(args.InterceptedDataVerifierFactory) { + return nil, process.ErrNilInterceptedDataVerifierFactory + } argInterceptorFactory := &interceptorFactory.ArgInterceptedDataFactory{ CoreComponents: args.CoreComponents, @@ -163,9 +169,10 @@ func NewFullSyncInterceptorsContainerFactory( whiteListHandler: args.WhiteListHandler, whiteListerVerifiedTxs: args.WhiteListerVerifiedTxs, antifloodHandler: args.AntifloodHandler, - //TODO: inject the real peers holder once we have the peers mapping before epoch 
bootstrap finishes - preferredPeersHolder: disabled.NewPreferredPeersHolder(), - nodeOperationMode: args.NodeOperationMode, + // TODO: inject the real peers holder once we have the peers mapping before epoch bootstrap finishes + preferredPeersHolder: disabled.NewPreferredPeersHolder(), + nodeOperationMode: args.NodeOperationMode, + interceptedDataVerifierFactory: args.InterceptedDataVerifierFactory, } icf.globalThrottler, err = throttler.NewNumGoRoutinesThrottler(numGoRoutines) @@ -315,7 +322,7 @@ func (ficf *fullSyncInterceptorsContainerFactory) generateShardHeaderInterceptor keys := make([]string, numShards) interceptorsSlice := make([]process.Interceptor, numShards) - //wire up to topics: shardBlocks_0_META, shardBlocks_1_META ... + // wire up to topics: shardBlocks_0_META, shardBlocks_1_META ... for idx := uint32(0); idx < numShards; idx++ { identifierHeader := factory.ShardBlocksTopic + tmpSC.CommunicationIdentifier(idx) if ficf.checkIfInterceptorExists(identifierHeader) { @@ -343,21 +350,28 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneShardHeaderIntercepto argProcessor := &processor.ArgHdrInterceptorProcessor{ Headers: ficf.dataPool.Headers(), BlockBlackList: ficf.blockBlackList, + Proofs: ficf.dataPool.Proofs(), } hdrProcessor, err := processor.NewHdrInterceptorProcessor(argProcessor) if err != nil { return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: topic, - DataFactory: hdrFactory, - Processor: hdrProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), + Topic: topic, + DataFactory: hdrFactory, + Processor: hdrProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -509,7 +523,7 @@ func (ficf *fullSyncInterceptorsContainerFactory) generateTxInterceptors() error interceptorSlice[int(idx)] = interceptor } - //tx interceptor for metachain topic + // tx interceptor for metachain topic identifierTx := factory.TransactionTopic + shardC.CommunicationIdentifier(core.MetachainShardId) if !ficf.checkIfInterceptorExists(identifierTx) { interceptor, err := ficf.createOneTxInterceptor(identifierTx) @@ -551,17 +565,24 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneTxInterceptor(topic s return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), - DataFactory: txFactory, - Processor: txProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: topic, + Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), + Hasher: ficf.argInterceptorFactory.CoreComponents.Hasher(), + DataFactory: txFactory, + Processor: txProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + 
WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -586,17 +607,24 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneUnsignedTxInterceptor return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), - DataFactory: txFactory, - Processor: txProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: topic, + Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), + Hasher: ficf.argInterceptorFactory.CoreComponents.Hasher(), + DataFactory: txFactory, + Processor: txProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -621,17 +649,24 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneRewardTxInterceptor(t return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), - DataFactory: txFactory, - Processor: txProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: topic, + Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), + Hasher: ficf.argInterceptorFactory.CoreComponents.Hasher(), + DataFactory: txFactory, + Processor: txProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -694,16 +729,22 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneMiniBlocksInterceptor return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: topic, - DataFactory: txFactory, - Processor: txBlockBodyProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: topic, + DataFactory: txFactory, + Processor: txBlockBodyProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + 
InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -719,7 +760,11 @@ func (ficf *fullSyncInterceptorsContainerFactory) generateMetachainHeaderInterce return nil } - hdrFactory, err := interceptorFactory.NewInterceptedMetaHeaderDataFactory(ficf.argInterceptorFactory) + argsInterceptedMetaHeaderFactory := interceptorFactory.ArgInterceptedMetaHeaderFactory{ + ArgInterceptedDataFactory: *ficf.argInterceptorFactory, + } + + hdrFactory, err := interceptorFactory.NewInterceptedMetaHeaderDataFactory(&argsInterceptedMetaHeaderFactory) if err != nil { return err } @@ -727,23 +772,30 @@ func (ficf *fullSyncInterceptorsContainerFactory) generateMetachainHeaderInterce argProcessor := &processor.ArgHdrInterceptorProcessor{ Headers: ficf.dataPool.Headers(), BlockBlackList: ficf.blockBlackList, + Proofs: ficf.dataPool.Proofs(), } hdrProcessor, err := processor.NewHdrInterceptorProcessor(argProcessor) if err != nil { return err } - //only one metachain header topic + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(identifierHdr) + if err != nil { + return err + } + + // only one metachain header topic interceptor, err := interceptors.NewSingleDataInterceptor( interceptors.ArgSingleDataInterceptor{ - Topic: identifierHdr, - DataFactory: hdrFactory, - Processor: hdrProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: identifierHdr, + DataFactory: hdrFactory, + Processor: hdrProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -769,17 +821,24 @@ func (ficf *fullSyncInterceptorsContainerFactory) createOneTrieNodesInterceptor( return nil, err } + interceptedDataVerifier, err := ficf.interceptedDataVerifierFactory.Create(topic) + if err != nil { + return nil, err + } + interceptor, err := interceptors.NewMultiDataInterceptor( interceptors.ArgMultiDataInterceptor{ - Topic: topic, - Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), - DataFactory: trieNodesFactory, - Processor: trieNodesProcessor, - Throttler: ficf.globalThrottler, - AntifloodHandler: ficf.antifloodHandler, - WhiteListRequest: ficf.whiteListHandler, - CurrentPeerId: ficf.mainMessenger.ID(), - PreferredPeersHolder: ficf.preferredPeersHolder, + Topic: topic, + Marshalizer: ficf.argInterceptorFactory.CoreComponents.InternalMarshalizer(), + Hasher: ficf.argInterceptorFactory.CoreComponents.Hasher(), + DataFactory: trieNodesFactory, + Processor: trieNodesProcessor, + Throttler: ficf.globalThrottler, + AntifloodHandler: ficf.antifloodHandler, + WhiteListRequest: ficf.whiteListHandler, + CurrentPeerId: ficf.mainMessenger.ID(), + PreferredPeersHolder: ficf.preferredPeersHolder, + InterceptedDataVerifier: interceptedDataVerifier, }, ) if err != nil { @@ -811,7 +870,6 @@ func (ficf *fullSyncInterceptorsContainerFactory) generateRewardTxInterceptors() if err != nil { return err } - keys[int(idx)] = identifierScr interceptorSlice[int(idx)] = interceptor } diff --git a/update/genesis/export.go b/update/genesis/export.go index ba4e678a0f8..c87e977f487 100644 --- a/update/genesis/export.go +++ b/update/genesis/export.go @@ -17,6 +17,7 @@ import ( 
"github.com/multiversx/mx-chain-core-go/marshal" "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/common/errChan" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/sharding" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" @@ -444,7 +445,7 @@ func (se *stateExport) exportValidatorInfo(key string, validatorInfo *state.Shar func (se *stateExport) exportNodesSetupJson(validators state.ShardValidatorsInfoMapHandler) error { acceptedListsForExport := []common.PeerType{common.EligibleList, common.WaitingList, common.JailedList} - initialNodes := make([]*sharding.InitialNode, 0) + initialNodes := make([]*config.InitialNodeConfig, 0) for _, validator := range validators.GetAllValidatorsInfo() { if shouldExportValidator(validator, acceptedListsForExport) { @@ -459,7 +460,7 @@ func (se *stateExport) exportNodesSetupJson(validators state.ShardValidatorsInfo return nil } - initialNodes = append(initialNodes, &sharding.InitialNode{ + initialNodes = append(initialNodes, &config.InitialNodeConfig{ PubKey: pubKey, Address: rewardAddress, InitialRating: validator.GetRating(), @@ -471,20 +472,10 @@ func (se *stateExport) exportNodesSetupJson(validators state.ShardValidatorsInfo return strings.Compare(initialNodes[i].PubKey, initialNodes[j].PubKey) < 0 }) - genesisNodesSetupHandler := se.genesisNodesSetupHandler - nodesSetup := &sharding.NodesSetup{ - StartTime: genesisNodesSetupHandler.GetStartTime(), - RoundDuration: genesisNodesSetupHandler.GetRoundDuration(), - ConsensusGroupSize: genesisNodesSetupHandler.GetShardConsensusGroupSize(), - MinNodesPerShard: genesisNodesSetupHandler.MinNumberOfShardNodes(), - MetaChainConsensusGroupSize: genesisNodesSetupHandler.GetMetaConsensusGroupSize(), - MetaChainMinNodes: genesisNodesSetupHandler.MinNumberOfMetaNodes(), - Hysteresis: genesisNodesSetupHandler.GetHysteresis(), - Adaptivity: genesisNodesSetupHandler.GetAdaptivity(), - InitialNodes: initialNodes, - } + exportedNodesConfig := se.genesisNodesSetupHandler.ExportNodesConfig() + exportedNodesConfig.InitialNodes = initialNodes - nodesSetupBytes, err := json.MarshalIndent(nodesSetup, "", " ") + nodesSetupBytes, err := json.MarshalIndent(exportedNodesConfig, "", " ") if err != nil { return err } diff --git a/update/interface.go b/update/interface.go index 277861ceada..385e9b62e56 100644 --- a/update/interface.go +++ b/update/interface.go @@ -8,6 +8,7 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" "github.com/multiversx/mx-chain-go/common" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" "github.com/multiversx/mx-chain-go/state" @@ -262,6 +263,7 @@ type GenesisNodesSetupHandler interface { GetAdaptivity() bool NumberOfShards() uint32 MinNumberOfNodes() uint32 + ExportNodesConfig() config.NodesConfig IsInterfaceNil() bool } diff --git a/update/mock/epochStartNotifierStub.go b/update/mock/epochStartNotifierStub.go index 0a7b89387f5..96bb821a1f1 100644 --- a/update/mock/epochStartNotifierStub.go +++ b/update/mock/epochStartNotifierStub.go @@ -2,6 +2,7 @@ package mock import ( 
"github.com/multiversx/mx-chain-core-go/data" + "github.com/multiversx/mx-chain-go/epochStart" ) diff --git a/update/mock/nodesSetupHandlerStub.go b/update/mock/nodesSetupHandlerStub.go index 499e14187c2..9c6b64e99a7 100644 --- a/update/mock/nodesSetupHandlerStub.go +++ b/update/mock/nodesSetupHandlerStub.go @@ -3,6 +3,7 @@ package mock import ( "time" + "github.com/multiversx/mx-chain-go/config" "github.com/multiversx/mx-chain-go/sharding/nodesCoordinator" ) @@ -10,6 +11,7 @@ import ( type GenesisNodesSetupHandlerStub struct { InitialNodesInfoForShardCalled func(shardId uint32) ([]nodesCoordinator.GenesisNodeInfoHandler, []nodesCoordinator.GenesisNodeInfoHandler, error) InitialNodesInfoCalled func() (map[uint32][]nodesCoordinator.GenesisNodeInfoHandler, map[uint32][]nodesCoordinator.GenesisNodeInfoHandler) + ExportNodesConfigCalled func() config.NodesConfig GetStartTimeCalled func() int64 GetRoundDurationCalled func() uint64 GetChainIdCalled func() string @@ -150,6 +152,15 @@ func (g *GenesisNodesSetupHandlerStub) MinNumberOfNodes() uint32 { return 1 } +// ExportNodesConfig - +func (g *GenesisNodesSetupHandlerStub) ExportNodesConfig() config.NodesConfig { + if g.ExportNodesConfigCalled != nil { + return g.ExportNodesConfigCalled() + } + + return config.NodesConfig{} +} + // IsInterfaceNil - func (g *GenesisNodesSetupHandlerStub) IsInterfaceNil() bool { return g == nil diff --git a/update/sync/coordinator_test.go b/update/sync/coordinator_test.go index b56b2d8f99a..e5f3067dd33 100644 --- a/update/sync/coordinator_test.go +++ b/update/sync/coordinator_test.go @@ -11,18 +11,20 @@ import ( "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/data/block" dataTransaction "github.com/multiversx/mx-chain-core-go/data/transaction" + "github.com/stretchr/testify/require" + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/state" "github.com/multiversx/mx-chain-go/storage" "github.com/multiversx/mx-chain-go/testscommon" + "github.com/multiversx/mx-chain-go/testscommon/cache" stateMock "github.com/multiversx/mx-chain-go/testscommon/state" storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage" "github.com/multiversx/mx-chain-go/testscommon/syncer" trieMock "github.com/multiversx/mx-chain-go/testscommon/trie" "github.com/multiversx/mx-chain-go/update" "github.com/multiversx/mx-chain-go/update/mock" - "github.com/stretchr/testify/require" ) func createHeaderSyncHandler(retErr bool) update.HeaderSyncHandler { @@ -71,7 +73,7 @@ func createPendingMiniBlocksSyncHandler() update.EpochStartPendingMiniBlocksSync mb := &block.MiniBlock{TxHashes: [][]byte{txHash}} args := ArgsNewPendingMiniBlocksSyncer{ Storage: &storageStubs.StorerStub{}, - Cache: &testscommon.CacherStub{ + Cache: &cache.CacherStub{ RegisterHandlerCalled: func(f func(key []byte, val interface{})) {}, PeekCalled: func(key []byte) (value interface{}, ok bool) { return mb, true diff --git a/update/sync/syncEpochStartShardHeaders.go b/update/sync/syncEpochStartShardHeaders.go index f7a86994214..d66f14960a4 100644 --- a/update/sync/syncEpochStartShardHeaders.go +++ 
b/update/sync/syncEpochStartShardHeaders.go @@ -1,6 +1,7 @@ package sync import ( + "bytes" "context" "sync" "time" @@ -9,11 +10,15 @@ import ( "github.com/multiversx/mx-chain-core-go/core/check" "github.com/multiversx/mx-chain-core-go/data" "github.com/multiversx/mx-chain-core-go/marshal" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/update" ) +// TODO: there is some duplicated code between this syncer and the other syncers in this package that could be refactored + var _ update.PendingEpochStartShardHeaderSyncHandler = (*pendingEpochStartShardHeader)(nil) type pendingEpochStartShardHeader struct { @@ -22,6 +27,7 @@ type pendingEpochStartShardHeader struct { epochStartHash []byte latestReceivedHeader data.HeaderHandler latestReceivedHash []byte + latestReceivedProof data.HeaderProofHandler targetEpoch uint32 targetShardId uint32 headersPool dataRetriever.HeadersPool @@ -32,13 +38,17 @@ type pendingEpochStartShardHeader struct { synced bool requestHandler process.RequestHandler waitTimeBetweenRequests time.Duration + enableEpochsHandler common.EnableEpochsHandler + proofsPool dataRetriever.ProofsPool } // ArgsPendingEpochStartShardHeaderSyncer defines the arguments needed for the sycner type ArgsPendingEpochStartShardHeaderSyncer struct { - HeadersPool dataRetriever.HeadersPool - Marshalizer marshal.Marshalizer - RequestHandler process.RequestHandler + HeadersPool dataRetriever.HeadersPool + ProofsPool dataRetriever.ProofsPool + Marshalizer marshal.Marshalizer + RequestHandler process.RequestHandler + EnableEpochsHandler common.EnableEpochsHandler } // NewPendingEpochStartShardHeaderSyncer creates a syncer for all pending miniblocks @@ -46,12 +56,18 @@ func NewPendingEpochStartShardHeaderSyncer(args ArgsPendingEpochStartShardHeader if check.IfNil(args.HeadersPool) { return nil, update.ErrNilHeadersPool } + if check.IfNil(args.ProofsPool) { + return nil, dataRetriever.ErrNilProofsPool + } if check.IfNil(args.Marshalizer) { return nil, dataRetriever.ErrNilMarshalizer } if check.IfNil(args.RequestHandler) { return nil, process.ErrNilRequestHandler } + if check.IfNil(args.EnableEpochsHandler) { + return nil, update.ErrNilEnableEpochsHandler + } p := &pendingEpochStartShardHeader{ mutPending: sync.RWMutex{}, @@ -60,6 +76,7 @@ func NewPendingEpochStartShardHeaderSyncer(args ArgsPendingEpochStartShardHeader targetEpoch: 0, targetShardId: 0, headersPool: args.HeadersPool, + proofsPool: args.ProofsPool, chReceived: make(chan bool), chNew: make(chan bool), requestHandler: args.RequestHandler, @@ -67,9 +84,11 @@ func NewPendingEpochStartShardHeaderSyncer(args ArgsPendingEpochStartShardHeader synced: false, marshaller: args.Marshalizer, waitTimeBetweenRequests: args.RequestHandler.RequestInterval(), + enableEpochsHandler: args.EnableEpochsHandler, } p.headersPool.RegisterHandler(p.receivedHeader) + p.proofsPool.RegisterHandler(p.receivedProof) return p, nil } @@ -79,6 +98,14 @@ func (p *pendingEpochStartShardHeader) SyncEpochStartShardHeader(shardId uint32, return p.syncEpochStartShardHeader(shardId, epoch, startNonce, ctx) } +func (p *pendingEpochStartShardHeader) hasProof(shardID uint32, hash []byte, epoch uint32) bool { + if !p.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) { + return true + } + + return 
p.proofsPool.HasProof(shardID, hash) +} + func (p *pendingEpochStartShardHeader) syncEpochStartShardHeader(shardId uint32, epoch uint32, startNonce uint64, ctx context.Context) error { _ = core.EmptyChannel(p.chReceived) _ = core.EmptyChannel(p.chNew) @@ -94,6 +121,7 @@ func (p *pendingEpochStartShardHeader) syncEpochStartShardHeader(shardId uint32, p.mutPending.Lock() p.stopSyncing = false p.requestHandler.RequestShardHeaderByNonce(shardId, nonce+1) + p.requestHandler.RequestEquivalentProofByNonce(shardId, nonce+1) p.mutPending.Unlock() select { @@ -130,7 +158,18 @@ func (p *pendingEpochStartShardHeader) receivedHeader(header data.HeaderHandler, p.latestReceivedHash = headerHash p.latestReceivedHeader = header + if !p.hasProof(header.GetShardID(), headerHash, header.GetEpoch()) { + go p.requestHandler.RequestEquivalentProofByHash(header.GetShardID(), headerHash) + p.mutPending.Unlock() + return + } + p.mutPending.Unlock() + p.updateReceivedHeaderAndProof(header, headerHash) +} + +func (p *pendingEpochStartShardHeader) updateReceivedHeaderAndProof(header data.HeaderHandler, headerHash []byte) { + p.mutPending.Lock() if header.GetEpoch() != p.targetEpoch || !header.IsStartOfEpochBlock() { p.mutPending.Unlock() p.chNew <- true @@ -144,6 +183,25 @@ func (p *pendingEpochStartShardHeader) receivedHeader(header data.HeaderHandler, p.chReceived <- true } +func (p *pendingEpochStartShardHeader) receivedProof(proof data.HeaderProofHandler) { + p.mutPending.Lock() + if p.stopSyncing { + p.mutPending.Unlock() + return + } + if !check.IfNil(p.latestReceivedProof) && bytes.Equal(proof.GetHeaderHash(), p.latestReceivedProof.GetHeaderHash()) { + p.mutPending.Unlock() + return + } + if !bytes.Equal(proof.GetHeaderHash(), p.latestReceivedHash) { + p.mutPending.Unlock() + return + } + p.latestReceivedProof = proof + p.mutPending.Unlock() + p.updateReceivedHeaderAndProof(p.latestReceivedHeader, p.latestReceivedHash) +} + // GetEpochStartHeader returns the synced epoch start header func (p *pendingEpochStartShardHeader) GetEpochStartHeader() (data.HeaderHandler, []byte, error) { p.mutPending.RLock() diff --git a/update/sync/syncEpochStartShardHeaders_test.go b/update/sync/syncEpochStartShardHeaders_test.go index 201307864d8..7bf4719555d 100644 --- a/update/sync/syncEpochStartShardHeaders_test.go +++ b/update/sync/syncEpochStartShardHeaders_test.go @@ -6,21 +6,28 @@ import ( "testing" "time" + "github.com/multiversx/mx-chain-core-go/core" "github.com/multiversx/mx-chain-core-go/data/block" + "github.com/stretchr/testify/require" + + "github.com/multiversx/mx-chain-go/common" "github.com/multiversx/mx-chain-go/dataRetriever" "github.com/multiversx/mx-chain-go/process" "github.com/multiversx/mx-chain-go/testscommon" + dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever" + "github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock" "github.com/multiversx/mx-chain-go/update" "github.com/multiversx/mx-chain-go/update/mock" - "github.com/stretchr/testify/require" ) func createMockArgsPendingEpochStartShardHeader() ArgsPendingEpochStartShardHeaderSyncer { return ArgsPendingEpochStartShardHeaderSyncer{ - HeadersPool: &mock.HeadersCacherStub{}, - Marshalizer: &mock.MarshalizerFake{}, - RequestHandler: &testscommon.RequestHandlerStub{}, + HeadersPool: &mock.HeadersCacherStub{}, + 
 		Marshalizer:         &mock.MarshalizerFake{},
+		RequestHandler:      &testscommon.RequestHandlerStub{},
+		EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{},
+		ProofsPool:          &dataRetrieverMocks.ProofsPoolMock{},
 	}
 }
 
@@ -82,16 +89,14 @@ func TestSyncEpochStartShardHeader_Success(t *testing.T) {
 		Epoch:              epoch,
 		EpochStartMetaHash: []byte("metaHash"),
 	}
-
-	headersPool := &mock.HeadersCacherStub{}
-	args := ArgsPendingEpochStartShardHeaderSyncer{
-		HeadersPool: headersPool,
-		Marshalizer: &mock.MarshalizerFake{},
-		RequestHandler: &testscommon.RequestHandlerStub{
-			RequestShardHeaderByNonceCalled: func(shardID uint32, nonce uint64) {},
-		},
+	proof := &block.HeaderProof{
+		HeaderShardId: shardID,
+		HeaderNonce:   startNonce + 2,
+		HeaderHash:    headerHash,
+		HeaderEpoch:   epoch,
 	}
+	args := createPendingEpochStartShardHeaderSyncerArgs()
 
 	syncer, err := NewPendingEpochStartShardHeaderSyncer(args)
 	require.Nil(t, err)
 
@@ -103,11 +108,20 @@ func TestSyncEpochStartShardHeader_Success(t *testing.T) {
 			Nonce:   startNonce + 1,
 			Epoch:   epoch - 1,
 		}
-		syncer.receivedHeader(h1, []byte("hash1"))
+		h1Hash := []byte("hash1")
+		p1 := &block.HeaderProof{
+			HeaderShardId: shardID,
+			HeaderNonce:   startNonce + 1,
+			HeaderHash:    h1Hash,
+			HeaderEpoch:   epoch - 1,
+		}
+		syncer.receivedHeader(h1, h1Hash)
+		syncer.receivedProof(p1)
 
 		// Wait a bit, then receive epoch start header
 		time.Sleep(100 * time.Millisecond)
 		syncer.receivedHeader(header, headerHash)
+		syncer.receivedProof(proof)
 	}()
 
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@@ -129,15 +143,7 @@ func TestSyncEpochStartShardHeader_Timeout(t *testing.T) {
 	epoch := uint32(10)
 	startNonce := uint64(100)
 
-	headersPool := &mock.HeadersCacherStub{}
-	args := ArgsPendingEpochStartShardHeaderSyncer{
-		HeadersPool: headersPool,
-		Marshalizer: &mock.MarshalizerFake{},
-		RequestHandler: &testscommon.RequestHandlerStub{
-			RequestShardHeaderByNonceCalled: func(shardID uint32, nonce uint64) {},
-		},
-	}
-
+	args := createPendingEpochStartShardHeaderSyncerArgs()
 	syncer, err := NewPendingEpochStartShardHeaderSyncer(args)
 	require.Nil(t, err)
 
@@ -175,16 +181,14 @@ func TestSyncEpochStartShardHeader_ClearFields(t *testing.T) {
 		Epoch:              epoch,
 		EpochStartMetaHash: []byte("metaHash"),
 	}
-
-	headersPool := &mock.HeadersCacherStub{}
-	args := ArgsPendingEpochStartShardHeaderSyncer{
-		HeadersPool: headersPool,
-		Marshalizer: &mock.MarshalizerFake{},
-		RequestHandler: &testscommon.RequestHandlerStub{
-			RequestShardHeaderByNonceCalled: func(shardID uint32, nonce uint64) {},
-		},
+	proof := &block.HeaderProof{
+		HeaderShardId: shardID,
+		HeaderNonce:   startNonce + 1,
+		HeaderHash:    headerHash,
+		HeaderEpoch:   epoch,
 	}
+	args := createPendingEpochStartShardHeaderSyncerArgs()
 
 	syncer, err := NewPendingEpochStartShardHeaderSyncer(args)
 	require.Nil(t, err)
 
@@ -192,6 +196,7 @@ func TestSyncEpochStartShardHeader_ClearFields(t *testing.T) {
 	go func() {
 		time.Sleep(100 * time.Millisecond)
 		syncer.receivedHeader(header, headerHash)
+		syncer.receivedProof(proof)
 	}()
 
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@@ -228,16 +233,14 @@ func TestSyncEpochStartShardHeader_DifferentShardIDsShouldNotInterfere(t *testin
 		Epoch:              epoch,
 		EpochStartMetaHash: []byte("metaHash"),
 	}
-
-	headersPool := &mock.HeadersCacherStub{}
-	args := ArgsPendingEpochStartShardHeaderSyncer{
-		HeadersPool: headersPool,
-		Marshalizer: &mock.MarshalizerFake{},
-		RequestHandler: &testscommon.RequestHandlerStub{
-			RequestShardHeaderByNonceCalled: func(shardID uint32, nonce uint64) {},
-		},
+	proof := &block.HeaderProof{
+		HeaderShardId: shardID,
+		HeaderNonce:   startNonce + 2,
+		HeaderHash:    headerHash,
+		HeaderEpoch:   epoch,
 	}
+	args := createPendingEpochStartShardHeaderSyncerArgs()
 
 	syncer, err := NewPendingEpochStartShardHeaderSyncer(args)
 	require.Nil(t, err)
 
@@ -254,6 +257,7 @@ func TestSyncEpochStartShardHeader_DifferentShardIDsShouldNotInterfere(t *testin
 		// Wait and then send correct shard header
 		time.Sleep(100 * time.Millisecond)
 		syncer.receivedHeader(header, headerHash)
+		syncer.receivedProof(proof)
 	}()
 
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@@ -276,12 +280,19 @@ func TestSyncEpochStartShardHeader_NonEpochStartHeadersShouldTriggerNextAttempt(
 	startNonce := uint64(100)
 	headerHash := []byte("epochStartHash")
 
+	nonEpochStartHeaderHash := []byte("nonEpochStartHash")
 	nonEpochStartHeader := &block.Header{
 		ShardID:            shardID,
 		Nonce:              startNonce + 1,
 		Epoch:              epoch - 1, // not the target epoch
 		EpochStartMetaHash: []byte("ignoreMetaHash"),
 	}
+	nonEpochStartProof := &block.HeaderProof{
+		HeaderShardId: shardID,
+		HeaderNonce:   startNonce + 1,
+		HeaderHash:    []byte("ignoreHash"),
+		HeaderEpoch:   epoch - 1,
+	}
 
 	epochStartHeader := &block.Header{
 		ShardID: shardID,
@@ -289,26 +300,26 @@ func TestSyncEpochStartShardHeader_NonEpochStartHeadersShouldTriggerNextAttempt(
 		Epoch:              epoch,
 		EpochStartMetaHash: []byte("metaHash"),
 	}
-
-	headersPool := &mock.HeadersCacherStub{}
-	args := ArgsPendingEpochStartShardHeaderSyncer{
-		HeadersPool: headersPool,
-		Marshalizer: &mock.MarshalizerFake{},
-		RequestHandler: &testscommon.RequestHandlerStub{
-			RequestShardHeaderByNonceCalled: func(shardID uint32, nonce uint64) {},
-		},
+	epochStartProof := &block.HeaderProof{
+		HeaderShardId: shardID,
+		HeaderNonce:   startNonce + 2,
+		HeaderHash:    headerHash,
+		HeaderEpoch:   epoch,
 	}
+	args := createPendingEpochStartShardHeaderSyncerArgs()
 
 	syncer, err := NewPendingEpochStartShardHeaderSyncer(args)
 	require.Nil(t, err)
 
 	go func() {
 		// first receive non-epoch start header
-		syncer.receivedHeader(nonEpochStartHeader, []byte("nonEpochStartHash"))
+		syncer.receivedHeader(nonEpochStartHeader, nonEpochStartHeaderHash)
+		syncer.receivedProof(nonEpochStartProof)
 
 		// after a small delay, receive epoch start header
 		time.Sleep(100 * time.Millisecond)
 		syncer.receivedHeader(epochStartHeader, headerHash)
+		syncer.receivedProof(epochStartProof)
 	}()
 
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@@ -337,16 +348,14 @@ func TestSyncEpochStartShardHeader_MultipleGoroutines(t *testing.T) {
 		Epoch:              epoch,
 		EpochStartMetaHash: []byte("methaHash"),
 	}
-
-	headersPool := &mock.HeadersCacherStub{}
-	args := ArgsPendingEpochStartShardHeaderSyncer{
-		HeadersPool: headersPool,
-		Marshalizer: &mock.MarshalizerFake{},
-		RequestHandler: &testscommon.RequestHandlerStub{
-			RequestShardHeaderByNonceCalled: func(shardID uint32, nonce uint64) {},
-		},
+	epochStartProof := &block.HeaderProof{
+		HeaderShardId: shardID,
+		HeaderNonce:   startNonce + 5,
+		HeaderHash:    headerHash,
+		HeaderEpoch:   epoch,
 	}
+	args := createPendingEpochStartShardHeaderSyncerArgs()
 
 	syncer, err := NewPendingEpochStartShardHeaderSyncer(args)
 	require.Nil(t, err)
 
@@ -356,6 +365,7 @@ func TestSyncEpochStartShardHeader_MultipleGoroutines(t *testing.T) {
 	go func() {
 		time.Sleep(200 * time.Millisecond)
 		syncer.receivedHeader(epochStartHeader, headerHash)
+		syncer.receivedProof(epochStartProof)
 	}()
 
 	// Use a wait group to wait for all noise goroutines to complete.
@@ -379,12 +389,20 @@ func TestSyncEpochStartShardHeader_MultipleGoroutines(t *testing.T) {
 			}
 
 			for nonce := startNonce + 1; nonce < startNonce+5; nonce++ {
+				noiseHash := []byte("noiseHash")
 				hdr := &block.Header{
 					ShardID: localShardID,
 					Nonce:   nonce,
 					Epoch:   localEpoch,
 				}
-				syncer.receivedHeader(hdr, []byte("noiseHash"))
+				noiseProof := &block.HeaderProof{
+					HeaderShardId: localShardID,
+					HeaderNonce:   nonce,
+					HeaderHash:    noiseHash,
+					HeaderEpoch:   localEpoch,
+				}
+				syncer.receivedHeader(hdr, noiseHash)
+				syncer.receivedProof(noiseProof)
 				time.Sleep(10 * time.Millisecond) // small delay between headers
 			}
 		}(i)
@@ -413,3 +431,25 @@ func TestPendingEpochStartShardHeader_IsInterfaceNil(t *testing.T) {
 	p = &pendingEpochStartShardHeader{}
 	require.False(t, p.IsInterfaceNil())
 }
+
+func createPendingEpochStartShardHeaderSyncerArgs() ArgsPendingEpochStartShardHeaderSyncer {
+	headersPool := &mock.HeadersCacherStub{}
+	proofsPool := &dataRetrieverMocks.ProofsPoolMock{}
+	args := ArgsPendingEpochStartShardHeaderSyncer{
+		HeadersPool: headersPool,
+		Marshalizer: &mock.MarshalizerFake{},
+		RequestHandler: &testscommon.RequestHandlerStub{
+			RequestShardHeaderByNonceCalled: func(shardID uint32, nonce uint64) {},
+		},
+		ProofsPool: proofsPool,
+		EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{
+			IsFlagEnabledInEpochCalled: func(flag core.EnableEpochFlag, epoch uint32) bool {
+				return flag == common.AndromedaFlag
+			},
+			IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool {
+				return flag == common.AndromedaFlag
+			},
+		},
+	}
+	return args
+}
diff --git a/update/sync/syncHeadersByHash.go b/update/sync/syncHeadersByHash.go
index 93a46b8d951..eb429d9b482 100644
--- a/update/sync/syncHeadersByHash.go
+++ b/update/sync/syncHeadersByHash.go
@@ -9,6 +9,7 @@ import (
 	"github.com/multiversx/mx-chain-core-go/core/check"
 	"github.com/multiversx/mx-chain-core-go/data"
 	"github.com/multiversx/mx-chain-core-go/marshal"
+	"github.com/multiversx/mx-chain-go/common"
 	"github.com/multiversx/mx-chain-go/dataRetriever"
 	"github.com/multiversx/mx-chain-go/process"
 	"github.com/multiversx/mx-chain-go/storage"
@@ -22,6 +23,7 @@ type syncHeadersByHash struct {
 	mapHeaders              map[string]data.HeaderHandler
 	mapHashes               map[string]struct{}
 	pool                    dataRetriever.HeadersPool
+	proofsPool              dataRetriever.ProofsPool
 	storage                 update.HistoryStorer
 	chReceivedAll           chan bool
 	marshalizer             marshal.Marshalizer
@@ -29,14 +31,17 @@ type syncHeadersByHash struct {
 	syncedAll               bool
 	requestHandler          process.RequestHandler
 	waitTimeBetweenRequests time.Duration
+	enableEpochsHandler     common.EnableEpochsHandler
 }
 
 // ArgsNewMissingHeadersByHashSyncer defines the arguments needed for the sycner
 type ArgsNewMissingHeadersByHashSyncer struct {
-	Storage        storage.Storer
-	Cache          dataRetriever.HeadersPool
-	Marshalizer    marshal.Marshalizer
-	RequestHandler process.RequestHandler
+	Storage             storage.Storer
+	Cache               dataRetriever.HeadersPool
+	ProofsPool          dataRetriever.ProofsPool
+	Marshalizer         marshal.Marshalizer
+	RequestHandler      process.RequestHandler
+	EnableEpochsHandler common.EnableEpochsHandler
 }
 
 // NewMissingheadersByHashSyncer creates a syncer for all missing headers
@@ -47,18 +52,25 @@ func NewMissingheadersByHashSyncer(args ArgsNewMissingHeadersByHashSyncer) (*syn
 	if check.IfNil(args.Cache) {
 		return nil, update.ErrNilCacher
 	}
+	if check.IfNil(args.ProofsPool) {
+		return nil, dataRetriever.ErrNilProofsPool
+	}
 	if check.IfNil(args.Marshalizer) {
 		return nil, dataRetriever.ErrNilMarshalizer
 	}
 	if check.IfNil(args.RequestHandler) {
 		return nil, process.ErrNilRequestHandler
 	}
+	if check.IfNil(args.EnableEpochsHandler) {
+		return nil, process.ErrNilEnableEpochsHandler
+	}
 
 	p := &syncHeadersByHash{
 		mutMissingHdrs:          sync.Mutex{},
 		mapHeaders:              make(map[string]data.HeaderHandler),
 		mapHashes:               make(map[string]struct{}),
 		pool:                    args.Cache,
+		proofsPool:              args.ProofsPool,
 		storage:                 args.Storage,
 		chReceivedAll:           make(chan bool),
 		requestHandler:          args.RequestHandler,
@@ -66,9 +78,11 @@ func NewMissingheadersByHashSyncer(args ArgsNewMissingHeadersByHashSyncer) (*syn
 		syncedAll:               false,
 		marshalizer:             args.Marshalizer,
 		waitTimeBetweenRequests: args.RequestHandler.RequestInterval(),
+		enableEpochsHandler:     args.EnableEpochsHandler,
 	}
 
 	p.pool.RegisterHandler(p.receivedHeader)
+	p.proofsPool.RegisterHandler(p.receivedProof)
 
 	return p, nil
 }
@@ -84,33 +98,21 @@ func (m *syncHeadersByHash) SyncMissingHeadersByHash(shardIDs []uint32, headersH
 	for {
 		requestedHdrs := 0
+		requestedProofs := 0
 
 		m.mutMissingHdrs.Lock()
 		m.stopSyncing = false
 		for hash, shardId := range mapHashesToRequest {
-			if _, ok := m.mapHeaders[hash]; ok {
-				delete(mapHashesToRequest, hash)
-				continue
+			requestedHeader, requestedProof := m.updateMapsAndRequestIfNeeded(shardId, hash, mapHashesToRequest)
+			if requestedHeader {
+				requestedHdrs++
 			}
-
-			m.mapHashes[hash] = struct{}{}
-			header, ok := m.getHeaderFromPoolOrStorage([]byte(hash))
-			if ok {
-				m.mapHeaders[hash] = header
-				delete(mapHashesToRequest, hash)
-				continue
+			if requestedProof {
+				requestedProofs++
 			}
-
-			requestedHdrs++
-			if shardId == core.MetachainShardId {
-				m.requestHandler.RequestMetaHeader([]byte(hash))
-				continue
-			}
-
-			m.requestHandler.RequestShardHeader(shardId, []byte(hash))
 		}
 
-		if requestedHdrs == 0 {
+		if requestedHdrs == 0 && requestedProofs == 0 {
 			m.stopSyncing = true
 			m.syncedAll = true
 			m.mutMissingHdrs.Unlock()
@@ -137,6 +139,64 @@ func (m *syncHeadersByHash) SyncMissingHeadersByHash(shardIDs []uint32, headersH
 	}
 }
 
+func (m *syncHeadersByHash) updateMapsAndRequestIfNeeded(
+	shardId uint32,
+	hash string,
+	mapHashesToRequest map[string]uint32,
+) (bool, bool) {
+	hasProof := false
+	hasHeader := false
+	hasRequestedProof := false
+	if header, ok := m.mapHeaders[hash]; ok {
+		hasHeader = ok
+		hasProof = m.hasProof(shardId, []byte(hash), header.GetEpoch())
+		if hasProof {
+			delete(mapHashesToRequest, hash)
+			return false, false
+		}
+	}
+
+	m.mapHashes[hash] = struct{}{}
+	header, ok := m.getHeaderFromPoolOrStorage([]byte(hash))
+	if ok {
+		hasHeader = ok
+		hasProof = m.hasProof(shardId, []byte(hash), header.GetEpoch())
+		if hasProof {
+			m.mapHeaders[hash] = header
+			delete(mapHashesToRequest, hash)
+			return false, false
+		}
+	}
+
+	// if header is missing, do not request the proof
+	// if a proof is needed for the header, it will be requested when header is received
+	if hasHeader {
+		if !hasProof {
+			hasRequestedProof = true
+			m.requestHandler.RequestEquivalentProofByHash(shardId, []byte(hash))
+		}
+
+		return false, hasRequestedProof
+	}
+
+	if shardId == core.MetachainShardId {
+		m.requestHandler.RequestMetaHeader([]byte(hash))
+		return true, hasRequestedProof
+	}
+
+	m.requestHandler.RequestShardHeader(shardId, []byte(hash))
+
+	return true, hasRequestedProof
+}
+
+func (m *syncHeadersByHash) hasProof(shardID uint32, hash []byte, epoch uint32) bool {
+	if !m.enableEpochsHandler.IsFlagEnabledInEpoch(common.AndromedaFlag, epoch) {
+		return true
+	}
+
+	return m.proofsPool.HasProof(shardID, hash)
+}
+
 // receivedHeader is a callback function when a new header was received
 // it will further ask for missing transactions
 func (m *syncHeadersByHash) receivedHeader(hdrHandler data.HeaderHandler, hdrHash []byte) {
@@ -151,6 +211,12 @@ func (m *syncHeadersByHash) receivedHeader(hdrHandler data.HeaderHandler, hdrHas
 		return
 	}
 
+	if !m.hasProof(hdrHandler.GetShardID(), hdrHash, hdrHandler.GetEpoch()) {
+		go m.requestHandler.RequestEquivalentProofByHash(hdrHandler.GetShardID(), hdrHash)
+		m.mutMissingHdrs.Unlock()
+		return
+	}
+
 	if _, ok := m.mapHeaders[string(hdrHash)]; ok {
 		m.mutMissingHdrs.Unlock()
 		return
@@ -164,6 +230,36 @@ func (m *syncHeadersByHash) receivedHeader(hdrHandler data.HeaderHandler, hdrHas
 	}
 }
 
+func (m *syncHeadersByHash) receivedProof(proofHandler data.HeaderProofHandler) {
+	m.mutMissingHdrs.Lock()
+	if m.stopSyncing {
+		m.mutMissingHdrs.Unlock()
+		return
+	}
+
+	hdrHash := proofHandler.GetHeaderHash()
+	if _, ok := m.mapHashes[string(hdrHash)]; !ok {
+		m.mutMissingHdrs.Unlock()
+		return
+	}
+
+	hdrHandler, ok := m.mapHeaders[string(hdrHash)]
+	if !ok {
+		hdrHandler, ok = m.getHeaderFromPoolOrStorage(hdrHash)
+		if !ok {
+			m.mutMissingHdrs.Unlock()
+			return
+		}
+	}
+
+	m.mapHeaders[string(hdrHash)] = hdrHandler
+	receivedAll := len(m.mapHashes) == len(m.mapHeaders)
+	m.mutMissingHdrs.Unlock()
+	if receivedAll {
+		m.chReceivedAll <- true
+	}
+}
+
 func (m *syncHeadersByHash) getHeaderFromPoolOrStorage(hash []byte) (data.HeaderHandler, bool) {
 	header, ok := m.getHeaderFromPool(hash)
 	if ok {
diff --git a/update/sync/syncHeadersByHash_test.go b/update/sync/syncHeadersByHash_test.go
index fa2cd5ede47..924ace354dc 100644
--- a/update/sync/syncHeadersByHash_test.go
+++ b/update/sync/syncHeadersByHash_test.go
@@ -3,6 +3,7 @@ package sync
 import (
 	"context"
 	"errors"
+	"sync"
 	"testing"
 	"time"
 
@@ -10,7 +11,10 @@ import (
 	"github.com/multiversx/mx-chain-core-go/data"
 	"github.com/multiversx/mx-chain-core-go/data/block"
 	"github.com/multiversx/mx-chain-go/dataRetriever"
+	"github.com/multiversx/mx-chain-go/process"
 	"github.com/multiversx/mx-chain-go/testscommon"
+	dataRetrieverMocks "github.com/multiversx/mx-chain-go/testscommon/dataRetriever"
+	"github.com/multiversx/mx-chain-go/testscommon/enableEpochsHandlerMock"
 	"github.com/multiversx/mx-chain-go/testscommon/genericMocks"
 	storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage"
 	"github.com/multiversx/mx-chain-go/update"
@@ -41,6 +45,14 @@ func TestNewMissingheadersByHashSyncer_NilParamsShouldErr(t *testing.T) {
 	nilRequestHandlerArgs.RequestHandler = nil
 	testInput[nilRequestHandlerArgs] = update.ErrNilRequestHandler
 
+	nilProofsPoolArgs := okArgs
+	nilProofsPoolArgs.ProofsPool = nil
+	testInput[nilProofsPoolArgs] = dataRetriever.ErrNilProofsPool
+
+	nilEnableEpochsHandlerArgs := okArgs
+	nilEnableEpochsHandlerArgs.EnableEpochsHandler = nil
+	testInput[nilEnableEpochsHandlerArgs] = process.ErrNilEnableEpochsHandler
+
 	for args, expectedErr := range testInput {
 		mhhs, err := NewMissingheadersByHashSyncer(args)
 		require.True(t, check.IfNil(mhhs))
@@ -149,19 +161,26 @@ func TestSyncHeadersByHash_GetHeadersShouldReceiveAndReturnOkMb(t *testing.T) {
 			handlerToNotify = handler
 		},
 	}
-	mhhs, _ := NewMissingheadersByHashSyncer(args)
-	require.NotNil(t, mhhs)
 
+	var wg sync.WaitGroup
 	expectedHash := []byte("hash")
 	expectedMB := &block.MetaBlock{Nonce: 37}
-	go func() {
-		time.Sleep(10 * time.Millisecond)
-		handlerToNotify(expectedMB, expectedHash)
-	}()
 
-	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
-	err := mhhs.SyncMissingHeadersByHash([]uint32{0}, [][]byte{[]byte("hash")}, ctx)
+	args.RequestHandler = &testscommon.RequestHandlerStub{
+		RequestShardHeaderCalled: func(shardID uint32, hash []byte) {
+			wg.Add(1)
+			go func() {
+				handlerToNotify(expectedMB, expectedHash)
+				wg.Done()
+			}()
+		},
+	}
+
+	mhhs, _ := NewMissingheadersByHashSyncer(args)
+	require.NotNil(t, mhhs)
+
+	err := mhhs.SyncMissingHeadersByHash([]uint32{0}, [][]byte{[]byte("hash")}, context.Background())
 	require.NoError(t, err)
-	cancel()
+	wg.Wait()
 
 	res, err := mhhs.GetHeaders()
 	require.NoError(t, err)
@@ -174,9 +193,11 @@ func TestSyncHeadersByHash_GetHeadersShouldReceiveAndReturnOkMb(t *testing.T) {
 
 func getMisingHeadersByHashSyncerArgs() ArgsNewMissingHeadersByHashSyncer {
 	return ArgsNewMissingHeadersByHashSyncer{
-		Storage:        genericMocks.NewStorerMock(),
-		Cache:          &mock.HeadersCacherStub{},
-		Marshalizer:    &mock.MarshalizerMock{},
-		RequestHandler: &testscommon.RequestHandlerStub{},
+		Storage:             genericMocks.NewStorerMock(),
+		Cache:               &mock.HeadersCacherStub{},
+		ProofsPool:          &dataRetrieverMocks.ProofsPoolMock{},
+		Marshalizer:         &mock.MarshalizerMock{},
+		RequestHandler:      &testscommon.RequestHandlerStub{},
+		EnableEpochsHandler: &enableEpochsHandlerMock.EnableEpochsHandlerStub{},
 	}
 }
diff --git a/update/sync/syncMiniBlocks_test.go b/update/sync/syncMiniBlocks_test.go
index 9fc8f96db1f..3f1c00a4773 100644
--- a/update/sync/syncMiniBlocks_test.go
+++ b/update/sync/syncMiniBlocks_test.go
@@ -10,19 +10,21 @@ import (
 	"github.com/multiversx/mx-chain-core-go/core"
 	"github.com/multiversx/mx-chain-core-go/data"
 	"github.com/multiversx/mx-chain-core-go/data/block"
+	"github.com/stretchr/testify/require"
+
 	"github.com/multiversx/mx-chain-go/dataRetriever"
 	"github.com/multiversx/mx-chain-go/process"
 	"github.com/multiversx/mx-chain-go/testscommon"
+	"github.com/multiversx/mx-chain-go/testscommon/cache"
 	storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage"
 	"github.com/multiversx/mx-chain-go/update"
 	"github.com/multiversx/mx-chain-go/update/mock"
-	"github.com/stretchr/testify/require"
 )
 
 func createMockArgsPendingMiniBlock() ArgsNewPendingMiniBlocksSyncer {
 	return ArgsNewPendingMiniBlocksSyncer{
 		Storage: &storageStubs.StorerStub{},
-		Cache: &testscommon.CacherStub{
+		Cache: &cache.CacherStub{
 			RegisterHandlerCalled: func(f func(key []byte, val interface{})) {},
 		},
 		Marshalizer:    &mock.MarshalizerFake{},
@@ -93,7 +95,7 @@ func TestSyncPendingMiniBlocksFromMeta_MiniBlocksInPool(t *testing.T) {
 	mb := &block.MiniBlock{}
 	args := ArgsNewPendingMiniBlocksSyncer{
 		Storage: &storageStubs.StorerStub{},
-		Cache: &testscommon.CacherStub{
+		Cache: &cache.CacherStub{
 			RegisterHandlerCalled: func(f func(key []byte, val interface{})) {},
 			PeekCalled: func(key []byte) (value interface{}, ok bool) {
 				miniBlockInPool = true
@@ -147,7 +149,7 @@ func TestSyncPendingMiniBlocksFromMeta_MiniBlocksInPoolWithRewards(t *testing.T)
 	}
 	args := ArgsNewPendingMiniBlocksSyncer{
 		Storage: &storageStubs.StorerStub{},
-		Cache: &testscommon.CacherStub{
+		Cache: &cache.CacherStub{
 			RegisterHandlerCalled: func(f func(key []byte, val interface{})) {},
 			PeekCalled: func(key []byte) (value interface{}, ok bool) {
 				miniBlockInPool = true
@@ -223,7 +225,7 @@ func TestSyncPendingMiniBlocksFromMeta_MiniBlocksInPoolMissingTimeout(t *testing
 				return nil, localErr
 			},
 		},
-		Cache: &testscommon.CacherStub{
+		Cache: &cache.CacherStub{
 			RegisterHandlerCalled: func(f func(key []byte, val interface{})) {},
 			PeekCalled: func(key []byte) (value interface{}, ok bool) {
 				return nil, false
@@ -274,7 +276,7 @@ func TestSyncPendingMiniBlocksFromMeta_MiniBlocksInPoolReceive(t *testing.T) {
 				return nil, localErr
 			},
 		},
-		Cache:          testscommon.NewCacherMock(),
+		Cache:          cache.NewCacherMock(),
 		Marshalizer:    &mock.MarshalizerFake{},
 		RequestHandler: &testscommon.RequestHandlerStub{},
 	}
@@ -322,7 +324,7 @@ func TestSyncPendingMiniBlocksFromMeta_MiniBlocksInStorageReceive(t *testing.T)
 				return mbBytes, nil
 			},
 		},
-		Cache: &testscommon.CacherStub{
+		Cache: &cache.CacherStub{
 			RegisterHandlerCalled: func(_ func(_ []byte, _ interface{})) {},
 			PeekCalled: func(key []byte) (interface{}, bool) {
 				return nil, false
@@ -376,7 +378,7 @@ func TestSyncPendingMiniBlocksFromMeta_GetMiniBlocksShouldWork(t *testing.T) {
 				return nil, localErr
 			},
 		},
-		Cache: &testscommon.CacherStub{
+		Cache: &cache.CacherStub{
 			RegisterHandlerCalled: func(_ func(_ []byte, _ interface{})) {},
 			PeekCalled: func(key []byte) (interface{}, bool) {
 				return nil, false
diff --git a/update/sync/syncTransactions_test.go b/update/sync/syncTransactions_test.go
index aa087bcbbe2..95ead49717f 100644
--- a/update/sync/syncTransactions_test.go
+++ b/update/sync/syncTransactions_test.go
@@ -16,17 +16,19 @@ import (
 	"github.com/multiversx/mx-chain-core-go/core/check"
 	"github.com/multiversx/mx-chain-core-go/data/block"
 	dataTransaction "github.com/multiversx/mx-chain-core-go/data/transaction"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
 	"github.com/multiversx/mx-chain-go/dataRetriever"
 	"github.com/multiversx/mx-chain-go/process"
 	"github.com/multiversx/mx-chain-go/state"
 	"github.com/multiversx/mx-chain-go/storage"
 	"github.com/multiversx/mx-chain-go/testscommon"
+	"github.com/multiversx/mx-chain-go/testscommon/cache"
 	dataRetrieverMock "github.com/multiversx/mx-chain-go/testscommon/dataRetriever"
 	storageStubs "github.com/multiversx/mx-chain-go/testscommon/storage"
 	"github.com/multiversx/mx-chain-go/update"
 	"github.com/multiversx/mx-chain-go/update/mock"
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
 )
 
 func createMockArgs() ArgsNewTransactionsSyncer {
@@ -529,7 +531,7 @@ func TestTransactionsSync_GetValidatorInfoFromPoolShouldWork(t *testing.T) {
 		ValidatorsInfoCalled: func() dataRetriever.ShardedDataCacherNotifier {
 			return &testscommon.ShardedDataStub{
 				ShardDataStoreCalled: func(cacheID string) storage.Cacher {
-					return &testscommon.CacherStub{
+					return &cache.CacherStub{
 						PeekCalled: func(key []byte) (value interface{}, ok bool) {
 							if bytes.Equal(key, txHash) {
 								return nil, true
@@ -690,7 +692,7 @@ func TestTransactionsSync_GetValidatorInfoFromPoolOrStorage(t *testing.T) {
 		ValidatorsInfoCalled: func() dataRetriever.ShardedDataCacherNotifier {
 			return &testscommon.ShardedDataStub{
 				ShardDataStoreCalled: func(cacheID string) storage.Cacher {
-					return &testscommon.CacherStub{
+					return &cache.CacherStub{
 						PeekCalled: func(key []byte) (value interface{}, ok bool) {
 							return nil, false
 						},
@@ -852,7 +854,7 @@ func getDataPoolsWithShardValidatorInfoAndTxHash(svi *state.ShardValidatorInfo,
 		ValidatorsInfoCalled: func() dataRetriever.ShardedDataCacherNotifier {
 			return &testscommon.ShardedDataStub{
 				ShardDataStoreCalled: func(cacheID string) storage.Cacher {
-					return &testscommon.CacherStub{
+					return &cache.CacherStub{
 						PeekCalled: func(key []byte) (value interface{}, ok bool) {
 							if bytes.Equal(key, txHash) {
 								return svi, true