diff --git a/epochStart/shardchain/trigger.go b/epochStart/shardchain/trigger.go index f1de14fa7e1..e6100f3ab3b 100644 --- a/epochStart/shardchain/trigger.go +++ b/epochStart/shardchain/trigger.go @@ -649,15 +649,18 @@ func (t *trigger) shouldUpdateTrigger(metaHdr *block.MetaBlock, metaBlockHash [] return false } - isMetaStartOfEpochForCurrentEpoch := metaHdr.Epoch == t.epoch && metaHdr.IsStartOfEpochBlock() - if isMetaStartOfEpochForCurrentEpoch { + isMetaStartOfEpochForCurrentOrOlderEpoch := metaHdr.Epoch <= t.epoch && metaHdr.IsStartOfEpochBlock() + if isMetaStartOfEpochForCurrentOrOlderEpoch { return false } - if _, ok := t.mapHashHdr[string(metaBlockHash)]; ok { - return false - } - if _, ok := t.mapEpochStartHdrs[string(metaBlockHash)]; ok { + _, foundHdrInMap := t.mapHashHdr[string(metaBlockHash)] + _, foundHdrInEpochStartMap := t.mapEpochStartHdrs[string(metaBlockHash)] + + finalizedMetaBlockHash, ok := t.mapFinalizedEpochs[metaHdr.Epoch] + foundHdrInFinalizedMap := ok && bytes.Equal(metaBlockHash, []byte(finalizedMetaBlockHash)) + + if foundHdrInMap && foundHdrInEpochStartMap && foundHdrInFinalizedMap { return false } diff --git a/epochStart/shardchain/trigger_test.go b/epochStart/shardchain/trigger_test.go index be4c6ccb00e..4a5d2e40998 100644 --- a/epochStart/shardchain/trigger_test.go +++ b/epochStart/shardchain/trigger_test.go @@ -630,7 +630,7 @@ func TestTrigger_RevertStateToBlockBehindEpochStartNoBlockInAnEpoch(t *testing.T assert.Equal(t, et.epochStartShardHeader.GetEpoch(), prevEpochHdr.Epoch) } -func TestTrigger_ReceivedHeaderChangeEpochFinalityAttestingRound(t *testing.T) { +func TestTrigger_ReceivedEpochStartHeaderChangeEpochFinalityAttestingRound(t *testing.T) { t.Parallel() args := createMockShardEpochStartTriggerArguments() @@ -668,6 +668,104 @@ func TestTrigger_ReceivedHeaderChangeEpochFinalityAttestingRound(t *testing.T) { require.Equal(t, uint64(102), epochStartTrigger.EpochFinalityAttestingRound()) } +func TestTrigger_ReceivedHeaderChangeEpochWithoutPrevHeader(t *testing.T) { + t.Parallel() + + args := createMockShardEpochStartTriggerArguments() + args.Validity = 1 + args.Finality = 1 + + oldEpHeader := &block.MetaBlock{Nonce: 99, Round: 99, Epoch: 0} + oldHash, _ := core.CalculateHash(args.Marshalizer, args.Hasher, oldEpHeader) + + hash := []byte("hash") + epochStartHeader := &block.MetaBlock{Nonce: 100, Round: 100, Epoch: 1, PrevHash: oldHash} + epochStartHeader.EpochStart.LastFinalizedHeaders = []block.EpochStartShardData{{ShardID: 0, RootHash: hash, HeaderHash: hash}} + epochStartHash, _ := core.CalculateHash(args.Marshalizer, args.Hasher, epochStartHeader) + + nextHeader := &block.MetaBlock{Nonce: 101, Round: 101, Epoch: 1, PrevHash: epochStartHash} + nextHeaderHash, _ := core.CalculateHash(args.Marshalizer, args.Hasher, nextHeader) + + numGetHeadersFromPoolCalls := 0 + args.DataPool = &dataRetrieverMock.PoolsHolderStub{ + HeadersCalled: func() dataRetriever.HeadersPool { + return &mock.HeadersCacherStub{ + GetHeaderByHashCalled: func(hash []byte) (data.HeaderHandler, error) { + if bytes.Equal(hash, oldHash) { + if numGetHeadersFromPoolCalls == 0 { + numGetHeadersFromPoolCalls++ + return nil, errors.New("not found") + } + + return oldEpHeader, nil + } + + if bytes.Equal(hash, epochStartHash) { + return epochStartHeader, nil + } + if bytes.Equal(hash, nextHeaderHash) { + return nextHeader, nil + } + + return &block.MetaBlock{}, nil + }, + GetHeaderByNonceAndShardIdCalled: func(hdrNonce uint64, shardId uint32) ([]data.HeaderHandler, [][]byte, error) { + if hdrNonce == epochStartHeader.Nonce { + return []data.HeaderHandler{epochStartHeader}, [][]byte{epochStartHash}, nil + } + + if hdrNonce == nextHeader.Nonce { + return []data.HeaderHandler{nextHeader}, [][]byte{nextHeaderHash}, nil + } + + return make([]data.HeaderHandler, 0), make([][]byte, 0), nil + }, + } + }, + MiniBlocksCalled: func() storage.Cacher { + return cache.NewCacherStub() + }, + CurrEpochValidatorInfoCalled: func() dataRetriever.ValidatorInfoCacher { + return &vic.ValidatorInfoCacherStub{} + }, + ProofsCalled: func() dataRetriever.ProofsPool { + return &dataRetrieverMock.ProofsPoolMock{} + }, + } + + args.Storage = &storageStubs.ChainStorerStub{ + GetStorerCalled: func(unitType dataRetriever.UnitType) (storage.Storer, error) { + return &storageStubs.StorerStub{ + GetCalled: func(key []byte) (b []byte, err error) { + if bytes.Equal(key, oldHash) { + return nil, errors.New("failed to get header from storage") + } + + if bytes.Equal(key, nextHeaderHash) { + return nextHeaderHash, nil + } + + return []byte("hash"), nil + }, + PutCalled: func(key, data []byte) error { + return nil + }, + }, nil + }, + } + + epochStartTrigger, err := NewEpochStartTrigger(args) + require.Nil(t, err) + + epochStartTrigger.receivedMetaBlock(epochStartHeader, epochStartHash) + + require.False(t, epochStartTrigger.isEpochStart) + + epochStartTrigger.receivedMetaBlock(epochStartHeader, epochStartHash) + + require.True(t, epochStartTrigger.isEpochStart) +} + func TestTrigger_ClearMissingValidatorsInfoMapShouldWork(t *testing.T) { t.Parallel()