diff --git a/consensus/broadcast/delayedBroadcast.go b/consensus/broadcast/delayedBroadcast.go
index 950b8a30bf9..f56e65694c7 100644
--- a/consensus/broadcast/delayedBroadcast.go
+++ b/consensus/broadcast/delayedBroadcast.go
@@ -180,6 +180,9 @@ func (dbb *delayedBlockBroadcaster) SetHeaderForValidator(vData *shared.Validato
 		return spos.ErrNilHeaderHash
 	}
 
+	dbb.mutDataForBroadcast.Lock()
+	defer dbb.mutDataForBroadcast.Unlock()
+
 	log.Trace("delayedBlockBroadcaster.SetHeaderForValidator",
 		"nbDelayedBroadcastData", len(dbb.delayedBroadcastData),
 		"nbValBroadcastData", len(dbb.valBroadcastData),
diff --git a/consensus/spos/consensusMessageValidator.go b/consensus/spos/consensusMessageValidator.go
index 93c6977eed9..cdcf507cbbf 100644
--- a/consensus/spos/consensusMessageValidator.go
+++ b/consensus/spos/consensusMessageValidator.go
@@ -159,13 +159,13 @@ func (cmv *consensusMessageValidator) checkConsensusMessageValidity(cnsMsg *cons
 
 	msgType := consensus.MessageType(cnsMsg.MsgType)
 
-	if cmv.consensusState.RoundIndex+1 < cnsMsg.RoundIndex {
+	if cmv.consensusState.GetRoundIndex()+1 < cnsMsg.RoundIndex {
 		log.Trace("received message from consensus topic has a future round",
 			"msg type", cmv.consensusService.GetStringValue(msgType),
 			"from", cnsMsg.PubKey,
 			"header hash", cnsMsg.BlockHeaderHash,
 			"msg round", cnsMsg.RoundIndex,
-			"round", cmv.consensusState.RoundIndex,
+			"round", cmv.consensusState.GetRoundIndex(),
 		)
 
 		return fmt.Errorf("%w : received message from consensus topic has a future round: %d",
@@ -173,13 +173,13 @@ func (cmv *consensusMessageValidator) checkConsensusMessageValidity(cnsMsg *cons
 			cnsMsg.RoundIndex)
 	}
 
-	if cmv.consensusState.RoundIndex > cnsMsg.RoundIndex {
+	if cmv.consensusState.GetRoundIndex() > cnsMsg.RoundIndex {
 		log.Trace("received message from consensus topic has a past round",
 			"msg type", cmv.consensusService.GetStringValue(msgType),
 			"from", cnsMsg.PubKey,
 			"header hash", cnsMsg.BlockHeaderHash,
 			"msg round", cnsMsg.RoundIndex,
-			"round", cmv.consensusState.RoundIndex,
+			"round", cmv.consensusState.GetRoundIndex(),
 		)
 
 		return fmt.Errorf("%w : received message from consensus topic has a past round: %d",
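For context, a minimal standalone sketch (not part of the patch) of the round-window rule the two guards above enforce: a message is accepted only for the current round or the immediately following one. All names here are illustrative, not the repository's own.

package main

import (
	"errors"
	"fmt"
)

var (
	errFutureRound = errors.New("message for future round")
	errPastRound   = errors.New("message for past round")
)

// checkRoundWindow mirrors the two checks in checkConsensusMessageValidity:
// reject anything beyond currentRound+1 or behind currentRound.
func checkRoundWindow(currentRound, msgRound int64) error {
	if currentRound+1 < msgRound {
		return fmt.Errorf("%w: %d", errFutureRound, msgRound)
	}
	if currentRound > msgRound {
		return fmt.Errorf("%w: %d", errPastRound, msgRound)
	}
	return nil
}

func main() {
	fmt.Println(checkRoundWindow(10, 12)) // future round -> error
	fmt.Println(checkRoundWindow(10, 9))  // past round -> error
	fmt.Println(checkRoundWindow(10, 11)) // next round -> nil
}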
diff --git a/consensus/spos/consensusMessageValidator_test.go b/consensus/spos/consensusMessageValidator_test.go
index 83dbf12057b..ef46fc9b75e 100644
--- a/consensus/spos/consensusMessageValidator_test.go
+++ b/consensus/spos/consensusMessageValidator_test.go
@@ -765,7 +765,7 @@ func TestCheckConsensusMessageValidity_ErrMessageForPastRound(t *testing.T) {
 	t.Parallel()
 
 	consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs()
-	consensusMessageValidatorArgs.ConsensusState.RoundIndex = 100
+	consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(100)
 	cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs)
 
 	headerBytes := make([]byte, 100)
@@ -788,7 +788,7 @@ func TestCheckConsensusMessageValidity_ErrMessageTypeLimitReached(t *testing.T)
 	t.Parallel()
 
 	consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs()
-	consensusMessageValidatorArgs.ConsensusState.RoundIndex = 10
+	consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(10)
 	cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs)
 
 	pubKey := []byte(consensusMessageValidatorArgs.ConsensusState.ConsensusGroup()[0])
@@ -834,7 +834,7 @@ func createMockConsensusMessage(args spos.ArgsConsensusMessageValidator, pubKey
 		MsgType:         int64(msgType),
 		PubKey:          pubKey,
 		Signature:       createDummyByteSlice(SignatureSize),
-		RoundIndex:      args.ConsensusState.RoundIndex,
+		RoundIndex:      args.ConsensusState.GetRoundIndex(),
 		BlockHeaderHash: createDummyByteSlice(args.HeaderHashSize),
 	}
 }
@@ -853,7 +853,7 @@ func TestCheckConsensusMessageValidity_InvalidSignature(t *testing.T) {
 	consensusMessageValidatorArgs.PeerSignatureHandler = &mock.PeerSignatureHandler{
 		Signer: signer,
 	}
-	consensusMessageValidatorArgs.ConsensusState.RoundIndex = 10
+	consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(10)
 	cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs)
 
 	headerBytes := make([]byte, 100)
@@ -876,7 +876,7 @@ func TestCheckConsensusMessageValidity_Ok(t *testing.T) {
 	t.Parallel()
 
 	consensusMessageValidatorArgs := createDefaultConsensusMessageValidatorArgs()
-	consensusMessageValidatorArgs.ConsensusState.RoundIndex = 10
+	consensusMessageValidatorArgs.ConsensusState.SetRoundIndex(10)
 	cmv, _ := spos.NewConsensusMessageValidator(consensusMessageValidatorArgs)
 
 	headerBytes := make([]byte, 100)
diff --git a/consensus/spos/consensusState.go b/consensus/spos/consensusState.go
index a7a8ee3de65..8904717b7ea 100644
--- a/consensus/spos/consensusState.go
+++ b/consensus/spos/consensusState.go
@@ -42,6 +42,8 @@ type ConsensusState struct {
 	*roundConsensus
 	*roundThreshold
 	*roundStatus
+
+	mutState sync.RWMutex
 }
 
 // NewConsensusState creates a new ConsensusState object
@@ -392,21 +394,33 @@ func (cns *ConsensusState) ResetRoundsWithoutReceivedMessages(pkBytes []byte, pi
 
 // GetRoundCanceled returns the state of the current round
 func (cns *ConsensusState) GetRoundCanceled() bool {
+	cns.mutState.RLock()
+	defer cns.mutState.RUnlock()
+
 	return cns.RoundCanceled
 }
 
 // SetRoundCanceled sets the state of the current round
 func (cns *ConsensusState) SetRoundCanceled(roundCanceled bool) {
+	cns.mutState.Lock()
+	defer cns.mutState.Unlock()
+
 	cns.RoundCanceled = roundCanceled
 }
 
 // GetRoundIndex returns the index of the current round
 func (cns *ConsensusState) GetRoundIndex() int64 {
+	cns.mutState.RLock()
+	defer cns.mutState.RUnlock()
+
 	return cns.RoundIndex
 }
 
 // SetRoundIndex sets the index of the current round
 func (cns *ConsensusState) SetRoundIndex(roundIndex int64) {
+	cns.mutState.Lock()
+	defer cns.mutState.Unlock()
+
 	cns.RoundIndex = roundIndex
 }
 
@@ -447,11 +461,17 @@ func (cns *ConsensusState) GetHeader() data.HeaderHandler {
 
 // GetWaitingAllSignaturesTimeOut returns the state of the waiting all signatures time out
 func (cns *ConsensusState) GetWaitingAllSignaturesTimeOut() bool {
+	cns.mutState.RLock()
+	defer cns.mutState.RUnlock()
+
 	return cns.WaitingAllSignaturesTimeOut
 }
 
 // SetWaitingAllSignaturesTimeOut sets the state of the waiting all signatures time out
 func (cns *ConsensusState) SetWaitingAllSignaturesTimeOut(waitingAllSignaturesTimeOut bool) {
+	cns.mutState.Lock()
+	defer cns.mutState.Unlock()
+
 	cns.WaitingAllSignaturesTimeOut = waitingAllSignaturesTimeOut
 }
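A minimal sketch of the pattern the patch applies throughout ConsensusState: an RWMutex guarding field access that was previously racy. The types below are illustrative stand-ins, not the repository's own.

package main

import (
	"fmt"
	"sync"
)

type roundState struct {
	mut        sync.RWMutex
	roundIndex int64
}

// GetRoundIndex takes a read lock, so concurrent readers do not block each other.
func (rs *roundState) GetRoundIndex() int64 {
	rs.mut.RLock()
	defer rs.mut.RUnlock()

	return rs.roundIndex
}

// SetRoundIndex takes the write lock, excluding readers and other writers.
func (rs *roundState) SetRoundIndex(idx int64) {
	rs.mut.Lock()
	defer rs.mut.Unlock()

	rs.roundIndex = idx
}

func main() {
	rs := &roundState{}
	var wg sync.WaitGroup
	for i := int64(0); i < 100; i++ {
		wg.Add(1)
		go func(v int64) { // concurrent writers: clean under `go test -race` thanks to the mutex
			defer wg.Done()
			rs.SetRoundIndex(v)
		}(i)
	}
	wg.Wait()
	fmt.Println("last round index:", rs.GetRoundIndex())
}

Note that the underlying fields stay exported via the embedded structs, so any caller still touching cns.RoundCanceled directly (as one assert in the test below does) bypasses the lock.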
diff --git a/consensus/spos/consensusState_test.go b/consensus/spos/consensusState_test.go
index 1a0a1de6bdd..6125c4091c4 100644
--- a/consensus/spos/consensusState_test.go
+++ b/consensus/spos/consensusState_test.go
@@ -70,12 +70,12 @@ func TestConsensusState_ResetConsensusStateShouldWork(t *testing.T) {
 	t.Parallel()
 
 	cns := internalInitConsensusState()
-	cns.RoundCanceled = true
-	cns.ExtendedCalled = true
-	cns.WaitingAllSignaturesTimeOut = true
+	cns.SetRoundCanceled(true)
+	cns.SetExtendedCalled(true)
+	cns.SetWaitingAllSignaturesTimeOut(true)
 
 	cns.ResetConsensusState()
 	assert.False(t, cns.RoundCanceled)
-	assert.False(t, cns.ExtendedCalled)
+	assert.False(t, cns.GetExtendedCalled())
 	assert.False(t, cns.WaitingAllSignaturesTimeOut)
 }
diff --git a/consensus/spos/roundConsensus.go b/consensus/spos/roundConsensus.go
index 503eb0b2a2a..dfe6eb88d29 100644
--- a/consensus/spos/roundConsensus.go
+++ b/consensus/spos/roundConsensus.go
@@ -66,15 +66,18 @@ func (rcns *roundConsensus) SetEligibleList(eligibleList map[string]struct{}) {
 
 // ConsensusGroup returns the consensus group ID's
 func (rcns *roundConsensus) ConsensusGroup() []string {
+	rcns.mut.RLock()
+	defer rcns.mut.RUnlock()
+
 	return rcns.consensusGroup
 }
 
 // SetConsensusGroup sets the consensus group ID's
 func (rcns *roundConsensus) SetConsensusGroup(consensusGroup []string) {
-	rcns.consensusGroup = consensusGroup
-
 	rcns.mut.Lock()
 
+	rcns.consensusGroup = consensusGroup
+
 	rcns.validatorRoundStates = make(map[string]*roundState)
 
 	for i := 0; i < len(consensusGroup); i++ {
@@ -86,11 +89,17 @@ func (rcns *roundConsensus) SetConsensusGroup(consensusGroup []string) {
 
 // Leader returns the leader for the current consensus
 func (rcns *roundConsensus) Leader() string {
+	rcns.mut.RLock()
+	defer rcns.mut.RUnlock()
+
 	return rcns.leader
 }
 
 // SetLeader sets the leader for the current consensus
 func (rcns *roundConsensus) SetLeader(leader string) {
+	rcns.mut.Lock()
+	defer rcns.mut.Unlock()
+
 	rcns.leader = leader
 }
 
@@ -156,6 +165,9 @@ func (rcns *roundConsensus) SelfJobDone(subroundId int) (bool, error) {
 
 // IsNodeInConsensusGroup method checks if the node is part of consensus group of the current round
 func (rcns *roundConsensus) IsNodeInConsensusGroup(node string) bool {
+	rcns.mut.RLock()
+	defer rcns.mut.RUnlock()
+
 	for i := 0; i < len(rcns.consensusGroup); i++ {
 		if rcns.consensusGroup[i] == node {
 			return true
@@ -177,6 +189,9 @@ func (rcns *roundConsensus) IsNodeInEligibleList(node string) bool {
 
 // ComputeSize method returns the number of messages received from the nodes belonging to the current jobDone group
 // related to this subround
 func (rcns *roundConsensus) ComputeSize(subroundId int) int {
+	rcns.mut.RLock()
+	defer rcns.mut.RUnlock()
+
 	n := 0
 
 	for i := 0; i < len(rcns.consensusGroup); i++ {
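The SetConsensusGroup change above is an ordering fix: the slice write moves inside the critical section, so a concurrent ConsensusGroup() reader can no longer observe it mid-update. A sketch of that property, with illustrative types only:

package main

import (
	"fmt"
	"sync"
)

type group struct {
	mut     sync.RWMutex
	members []string
	states  map[string]bool
}

// SetMembers writes both fields under one write lock, so readers always see
// a consistent (members, states) pair, which is the property the patch restores.
func (g *group) SetMembers(members []string) {
	g.mut.Lock()
	defer g.mut.Unlock()

	g.members = members
	g.states = make(map[string]bool, len(members))
	for _, m := range members {
		g.states[m] = false
	}
}

// Members reads under the matching read lock.
func (g *group) Members() []string {
	g.mut.RLock()
	defer g.mut.RUnlock()

	return g.members
}

func main() {
	g := &group{}
	g.SetMembers([]string{"a", "b", "c"})
	fmt.Println(g.Members())
}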
diff --git a/consensus/spos/subround_test.go b/consensus/spos/subround_test.go
index cd54782643c..8eb3e8e568d 100644
--- a/consensus/spos/subround_test.go
+++ b/consensus/spos/subround_test.go
@@ -90,7 +90,7 @@ func initConsensusState() *spos.ConsensusState {
 	)
 
 	cns.Data = []byte("X")
-	cns.RoundIndex = 0
+	cns.SetRoundIndex(0)
 	return cns
 }
 
diff --git a/consensus/spos/worker.go b/consensus/spos/worker.go
index 50fed737659..4407632143c 100644
--- a/consensus/spos/worker.go
+++ b/consensus/spos/worker.go
@@ -587,7 +587,7 @@ func (wrk *Worker) checkSelfState(cnsDta *consensus.Message) error {
 		return ErrMessageFromItself
 	}
 
-	if wrk.consensusState.RoundCanceled && wrk.consensusState.RoundIndex == cnsDta.RoundIndex {
+	if wrk.consensusState.GetRoundCanceled() && wrk.consensusState.GetRoundIndex() == cnsDta.RoundIndex {
 		return ErrRoundCanceled
 	}
 
@@ -623,7 +623,7 @@ func (wrk *Worker) executeMessage(cnsDtaList []*consensus.Message) {
 		if cnsDta == nil {
 			continue
 		}
-		if wrk.consensusState.RoundIndex != cnsDta.RoundIndex {
+		if wrk.consensusState.GetRoundIndex() != cnsDta.RoundIndex {
 			continue
 		}
 
@@ -674,7 +674,7 @@ func (wrk *Worker) checkChannels(ctx context.Context) {
 
 // Extend does an extension for the subround with subroundId
 func (wrk *Worker) Extend(subroundId int) {
-	wrk.consensusState.ExtendedCalled = true
+	wrk.consensusState.SetExtendedCalled(true)
 	log.Debug("extend function is called",
 		"subround", wrk.consensusService.GetSubroundName(subroundId))
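A small sketch of the filtering executeMessage performs above: messages whose round differs from the node's current round are silently skipped. Message and state types are stand-ins, not the repository's own.

package main

import "fmt"

type message struct {
	round int64
	data  string
}

// filterCurrentRound keeps only messages for the given round, mirroring the
// `GetRoundIndex() != cnsDta.RoundIndex -> continue` guard in the diff.
func filterCurrentRound(msgs []*message, currentRound int64) []*message {
	kept := msgs[:0] // reuse the backing array; input order is preserved
	for _, m := range msgs {
		if m == nil || m.round != currentRound {
			continue
		}
		kept = append(kept, m)
	}
	return kept
}

func main() {
	msgs := []*message{{round: 9, data: "late"}, {round: 10, data: "ok"}, nil}
	for _, m := range filterCurrentRound(msgs, 10) {
		fmt.Println(m.data) // prints "ok"
	}
}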
diff --git a/dataRetriever/dataPool/proofsCache/proofsPool.go b/dataRetriever/dataPool/proofsCache/proofsPool.go
index 2ae8faca4c9..b0de8e005cd 100644
--- a/dataRetriever/dataPool/proofsCache/proofsPool.go
+++ b/dataRetriever/dataPool/proofsCache/proofsPool.go
@@ -33,7 +33,7 @@ func (pp *proofsPool) AddProof(
 	shardID := headerProof.GetHeaderShardId()
 	headerHash := headerProof.GetHeaderHash()
 
-	hasProof := pp.HasProof(shardID, headerProof.GetHeaderHash())
+	hasProof := pp.HasProof(shardID, headerHash)
 	if hasProof {
 		log.Trace("there was already a valid proof for header, headerHash: %s", headerHash)
 		return nil
@@ -48,6 +48,14 @@ func (pp *proofsPool) AddProof(
 		pp.cache[shardID] = proofsPerShard
 	}
 
+	log.Trace("added proof to pool",
+		"header hash", headerProof.GetHeaderHash(),
+		"epoch", headerProof.GetHeaderEpoch(),
+		"nonce", headerProof.GetHeaderNonce(),
+		"shardID", headerProof.GetHeaderShardId(),
+		"pubKeys bitmap", headerProof.GetPubKeysBitmap(),
+	)
+
 	proofsPerShard.addProof(headerProof)
 
 	return nil
@@ -67,6 +75,11 @@ func (pp *proofsPool) CleanupProofsBehindNonce(shardID uint32, nonce uint64) err
 		return fmt.Errorf("%w: proofs cache per shard not found, shard ID: %d", ErrMissingProof, shardID)
 	}
 
+	log.Trace("cleanup proofs behind nonce",
+		"nonce", nonce,
+		"shardID", shardID,
+	)
+
 	proofsPerShard.cleanupProofsBehindNonce(nonce)
 
 	return nil
@@ -77,9 +90,18 @@ func (pp *proofsPool) GetProof(
 	shardID uint32,
 	headerHash []byte,
 ) (data.HeaderProofHandler, error) {
+	if headerHash == nil {
+		return nil, fmt.Errorf("nil header hash")
+	}
+
 	pp.mutCache.RLock()
 	defer pp.mutCache.RUnlock()
 
+	log.Trace("trying to get proof",
+		"headerHash", headerHash,
+		"shardID", shardID,
+	)
+
 	proofsPerShard, ok := pp.cache[shardID]
 	if !ok {
 		return nil, fmt.Errorf("%w: proofs cache per shard not found, shard ID: %d", ErrMissingProof, shardID)
diff --git a/process/block/baseProcess.go b/process/block/baseProcess.go
index d17140573c2..5ddb0608b1e 100644
--- a/process/block/baseProcess.go
+++ b/process/block/baseProcess.go
@@ -976,10 +976,14 @@ func (bp *baseProcessor) cleanupPools(headerHandler data.HeaderHandler) {
 		highestPrevFinalBlockNonce,
 	)
 
-	err := bp.dataPool.Proofs().CleanupProofsBehindNonce(bp.shardCoordinator.SelfId(), highestPrevFinalBlockNonce)
-	if err != nil {
-		log.Warn("%w: failed to cleanup notarized proofs behind nonce %d on shardID %d",
-			err, noncesToPrevFinal, bp.shardCoordinator.SelfId())
+	if bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, headerHandler.GetEpoch()) {
+		err := bp.dataPool.Proofs().CleanupProofsBehindNonce(bp.shardCoordinator.SelfId(), highestPrevFinalBlockNonce)
+		if err != nil {
+			log.Warn("failed to cleanup notarized proofs behind nonce",
+				"nonce", noncesToPrevFinal,
+				"shardID", bp.shardCoordinator.SelfId(),
+				"error", err)
+		}
 	}
 
 	if bp.shardCoordinator.SelfId() == core.MetachainShardId {
@@ -1011,10 +1015,14 @@ func (bp *baseProcessor) cleanupPoolsForCrossShard(
 		crossNotarizedHeader.GetNonce(),
 	)
 
-	err = bp.dataPool.Proofs().CleanupProofsBehindNonce(shardID, noncesToPrevFinal)
-	if err != nil {
-		log.Warn("%w: failed to cleanup notarized proofs behind nonce %d on shardID %d",
-			err, noncesToPrevFinal, shardID)
+	if bp.enableEpochsHandler.IsFlagEnabledInEpoch(common.EquivalentMessagesFlag, crossNotarizedHeader.GetEpoch()) {
+		err = bp.dataPool.Proofs().CleanupProofsBehindNonce(shardID, noncesToPrevFinal)
+		if err != nil {
+			log.Warn("failed to cleanup notarized proofs behind nonce",
+				"nonce", noncesToPrevFinal,
+				"shardID", shardID,
+				"error", err)
+		}
 	}
 }
diff --git a/testscommon/consensus/initializers/initializers.go b/testscommon/consensus/initializers/initializers.go
index aa3381281de..187c8f02892 100644
--- a/testscommon/consensus/initializers/initializers.go
+++ b/testscommon/consensus/initializers/initializers.go
@@ -92,7 +92,7 @@ func InitConsensusStateWithArgsVerifySignature(keysHandler consensus.KeysHandler
 		rstatus,
 	)
 	cns.Data = []byte("X")
-	cns.RoundIndex = 0
+	cns.SetRoundIndex(0)
 	return cns
 }
 
@@ -151,6 +151,6 @@ func createConsensusStateWithNodes(eligibleNodesPubKeys map[string]struct{}, con
 	)
 
 	cns.Data = []byte("X")
-	cns.RoundIndex = 0
+	cns.SetRoundIndex(0)
 	return cns
 }
diff --git a/testscommon/dataRetriever/poolFactory.go b/testscommon/dataRetriever/poolFactory.go
index b631e6d4ba2..54214ceedd0 100644
--- a/testscommon/dataRetriever/poolFactory.go
+++ b/testscommon/dataRetriever/poolFactory.go
@@ -226,6 +226,8 @@ func CreatePoolsHolderWithTxPool(txPool dataRetriever.ShardedDataCacherNotifier)
 	heartbeatPool, err := storageunit.NewCache(cacherConfig)
 	panicIfError("CreatePoolsHolderWithTxPool", err)
 
+	proofsPool := proofscache.NewProofsPool()
+
 	currentBlockTransactions := dataPool.NewCurrentBlockTransactionsPool()
 	currentEpochValidatorInfo := dataPool.NewCurrentEpochValidatorInfoPool()
 	dataPoolArgs := dataPool.DataPoolArgs{
@@ -243,7 +245,7 @@ func CreatePoolsHolderWithTxPool(txPool dataRetriever.ShardedDataCacherNotifier)
 		PeerAuthentications:  peerAuthPool,
 		Heartbeats:           heartbeatPool,
 		ValidatorsInfo:       validatorsInfo,
-		Proofs:               &ProofsPoolMock{},
+		Proofs:               proofsPool,
 	}
 	holder, err := dataPool.NewDataPool(dataPoolArgs)
 	panicIfError("CreatePoolsHolderWithTxPool", err)
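Closing sketch: the general shape of the proofs-pool cleanup exercised in baseProcess.go above, dropping every cached proof whose header nonce sits below a finality threshold. The types are simplified stand-ins for the per-shard cache in proofsCache.

package main

import (
	"fmt"
	"sync"
)

type proof struct {
	headerNonce uint64
	headerHash  string
}

type proofsCache struct {
	mut    sync.RWMutex
	proofs map[string]*proof // keyed by header hash, like the per-shard cache
}

// cleanupBehindNonce deletes proofs for headers that are already final,
// mirroring what CleanupProofsBehindNonce does in the diff.
func (pc *proofsCache) cleanupBehindNonce(nonce uint64) {
	pc.mut.Lock()
	defer pc.mut.Unlock()

	for hash, p := range pc.proofs {
		if p.headerNonce < nonce {
			delete(pc.proofs, hash)
		}
	}
}

func main() {
	pc := &proofsCache{proofs: map[string]*proof{
		"h1": {headerNonce: 5, headerHash: "h1"},
		"h2": {headerNonce: 12, headerHash: "h2"},
	}}
	pc.cleanupBehindNonce(10)                    // everything behind nonce 10 is gone
	fmt.Println(len(pc.proofs), "proof(s) left") // 1 proof(s) left
}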