diff --git a/storage/modules/block_storage.go b/storage/modules/block_storage.go index 4d41bafd9..078c613e4 100644 --- a/storage/modules/block_storage.go +++ b/storage/modules/block_storage.go @@ -21,6 +21,7 @@ import ( "log" "runtime" "strconv" + "strings" "github.com/neilotoole/errgroup" @@ -77,9 +78,29 @@ func getBlockIndexKey(index int64) []byte { return []byte(fmt.Sprintf("%s/%d", blockIndexNamespace, index)) } -func getTransactionHashKey(transactionIdentifier *types.TransactionIdentifier) (string, []byte) { +func getTransactionKey( + blockIdentifier *types.BlockIdentifier, + transactionIdentifier *types.TransactionIdentifier, +) (string, []byte) { return transactionNamespace, []byte( - fmt.Sprintf("%s/%s", transactionNamespace, transactionIdentifier.Hash), + fmt.Sprintf( + "%s/%s/%s", + transactionNamespace, + transactionIdentifier.Hash, + blockIdentifier.Hash, + ), + ) +} + +func getTransactionPrefix( + transactionIdentifier *types.TransactionIdentifier, +) []byte { + return []byte( + fmt.Sprintf( + "%s/%s/", + transactionNamespace, + transactionIdentifier.Hash, + ), ) } @@ -230,10 +251,24 @@ func (b *BlockStorage) pruneBlock( if err == nil { blockIdentifier := blockResponse.Block.BlockIdentifier - for _, tx := range blockResponse.OtherTransactions { - if err := b.pruneTransaction(ctx, dbTx, blockIdentifier, tx); err != nil { - return -1, fmt.Errorf("%w: %v", storageErrs.ErrCannotPruneTransaction, err) - } + // Remove all transaction hashes + g, gctx := errgroup.WithContextN(ctx, b.numCPU, b.numCPU) + for i := range blockResponse.OtherTransactions { + // We need to set variable before calling goroutine + // to avoid getting an updated pointer as loop iteration + // continues. + tx := blockResponse.OtherTransactions[i] + g.Go(func() error { + if err := b.pruneTransaction(gctx, dbTx, blockIdentifier, tx); err != nil { + return fmt.Errorf("%w: %v", storageErrs.ErrCannotPruneTransaction, err) + } + + return nil + }) + } + + if err := g.Wait(); err != nil { + return -1, err } _, blockKey := getBlockHashKey(blockIdentifier.Hash) @@ -846,54 +881,24 @@ func (b *BlockStorage) CreateBlockCache(ctx context.Context, blocks int) []*type return cache } -func (b *BlockStorage) updateTransaction( - ctx context.Context, - dbTx database.Transaction, - hashKey []byte, - namespace string, - blocks map[string]*blockTransaction, -) error { - encodedResult, err := b.db.Encoder().Encode(namespace, blocks) - if err != nil { - return fmt.Errorf("%w: %v", storageErrs.ErrTransactionDataEncodeFailed, err) - } - - if err := dbTx.Set(ctx, hashKey, encodedResult, true); err != nil { - return err - } - - return nil -} - func (b *BlockStorage) storeTransaction( ctx context.Context, transaction database.Transaction, blockIdentifier *types.BlockIdentifier, tx *types.Transaction, ) error { - namespace, hashKey := getTransactionHashKey(tx.TransactionIdentifier) - exists, val, err := transaction.Get(ctx, hashKey) - if err != nil { - return err - } - - var blocks map[string]*blockTransaction - if !exists { - blocks = make(map[string]*blockTransaction) - } else { - err := b.db.Encoder().Decode(namespace, val, &blocks, true) - if err != nil { - return fmt.Errorf("%w: could not decode transaction hash contents", err) - } - } - // We check for duplicates before storing transaction, - // so this must be a new key. - blocks[blockIdentifier.Hash] = &blockTransaction{ + namespace, hashKey := getTransactionKey(blockIdentifier, tx.TransactionIdentifier) + bt := &blockTransaction{ Transaction: tx, BlockIndex: blockIdentifier.Index, } - return b.updateTransaction(ctx, transaction, hashKey, namespace, blocks) + encodedResult, err := b.db.Encoder().Encode(namespace, bt) + if err != nil { + return fmt.Errorf("%w: %v", storageErrs.ErrTransactionDataEncodeFailed, err) + } + + return storeUniqueKey(ctx, transaction, hashKey, encodedResult, true) } func (b *BlockStorage) pruneTransaction( @@ -902,25 +907,17 @@ func (b *BlockStorage) pruneTransaction( blockIdentifier *types.BlockIdentifier, txIdentifier *types.TransactionIdentifier, ) error { - namespace, hashKey := getTransactionHashKey(txIdentifier) - exists, val, err := transaction.Get(ctx, hashKey) - if err != nil { - return err - } - if !exists { - return storageErrs.ErrTransactionNotFound - } - - var blocks map[string]*blockTransaction - if err := b.db.Encoder().Decode(namespace, val, &blocks, true); err != nil { - return fmt.Errorf("%w: could not decode transaction hash contents", err) + namespace, hashKey := getTransactionKey(blockIdentifier, txIdentifier) + bt := &blockTransaction{ + BlockIndex: blockIdentifier.Index, } - blocks[blockIdentifier.Hash] = &blockTransaction{ - BlockIndex: blockIdentifier.Index, + encodedResult, err := b.db.Encoder().Encode(namespace, bt) + if err != nil { + return fmt.Errorf("%w: %v", storageErrs.ErrTransactionDataEncodeFailed, err) } - return b.updateTransaction(ctx, transaction, hashKey, namespace, blocks) + return transaction.Set(ctx, hashKey, encodedResult, true) } func (b *BlockStorage) removeTransaction( @@ -929,36 +926,49 @@ func (b *BlockStorage) removeTransaction( blockIdentifier *types.BlockIdentifier, transactionIdentifier *types.TransactionIdentifier, ) error { - namespace, hashKey := getTransactionHashKey(transactionIdentifier) - exists, val, err := transaction.Get(ctx, hashKey) - if err != nil { - return err - } + _, hashKey := getTransactionKey(blockIdentifier, transactionIdentifier) - if !exists { - return fmt.Errorf( - "%w %s", - storageErrs.ErrTransactionDeleteFailed, - transactionIdentifier.Hash, - ) - } - - var blocks map[string]*blockTransaction - if err := b.db.Encoder().Decode(namespace, val, &blocks, true); err != nil { - return fmt.Errorf("%w: could not decode transaction hash contents", err) - } - - if _, exists := blocks[blockIdentifier.Hash]; !exists { - return fmt.Errorf("%w %s", storageErrs.ErrTransactionHashNotFound, blockIdentifier.Hash) - } + return transaction.Delete(ctx, hashKey) +} - delete(blocks, blockIdentifier.Hash) +func (b *BlockStorage) getAllTransactionsByIdentifier( + ctx context.Context, + transactionIdentifier *types.TransactionIdentifier, + txn database.Transaction, +) ([]*types.BlockTransaction, error) { + blockTransactions := []*types.BlockTransaction{} + _, err := txn.Scan( + ctx, + getTransactionPrefix(transactionIdentifier), + getTransactionPrefix(transactionIdentifier), + func(k []byte, v []byte) error { + // Decode blockTransaction + var bt blockTransaction + if err := b.db.Encoder().Decode(transactionNamespace, v, &bt, false); err != nil { + return fmt.Errorf("%w: unable to decode block data for transaction", err) + } - if len(blocks) == 0 { - return transaction.Delete(ctx, hashKey) + // Extract hash from key + splitKey := strings.Split(string(k), "/") + blockHash := splitKey[len(splitKey)-1] + + blockTransactions = append(blockTransactions, &types.BlockTransaction{ + BlockIdentifier: &types.BlockIdentifier{ + Index: bt.BlockIndex, + Hash: blockHash, + }, + Transaction: bt.Transaction, + }) + return nil + }, + false, + false, + ) + if err != nil { + return nil, err } - return b.updateTransaction(ctx, transaction, hashKey, namespace, blocks) + return blockTransactions, nil } // FindTransaction returns the most recent *types.BlockIdentifier containing the @@ -968,27 +978,20 @@ func (b *BlockStorage) FindTransaction( transactionIdentifier *types.TransactionIdentifier, txn database.Transaction, ) (*types.BlockIdentifier, *types.Transaction, error) { - namespace, key := getTransactionHashKey(transactionIdentifier) - txExists, tx, err := txn.Get(ctx, key) + blockTransactions, err := b.getAllTransactionsByIdentifier(ctx, transactionIdentifier, txn) if err != nil { return nil, nil, fmt.Errorf("%w: %v", storageErrs.ErrTransactionDBQueryFailed, err) } - if !txExists { + if len(blockTransactions) == 0 { return nil, nil, nil } - var blocks map[string]*blockTransaction - if err := b.db.Encoder().Decode(namespace, tx, &blocks, true); err != nil { - return nil, nil, fmt.Errorf("%w: unable to decode block data for transaction", err) - } - var newestBlock *types.BlockIdentifier var newestTransaction *types.Transaction - for hash, blockTransaction := range blocks { - b := &types.BlockIdentifier{Hash: hash, Index: blockTransaction.BlockIndex} - if newestBlock == nil || blockTransaction.BlockIndex > newestBlock.Index { - newestBlock = b + for _, blockTransaction := range blockTransactions { + if newestBlock == nil || blockTransaction.BlockIdentifier.Index > newestBlock.Index { + newestBlock = blockTransaction.BlockIdentifier newestTransaction = blockTransaction.Transaction } } @@ -1016,7 +1019,7 @@ func (b *BlockStorage) findBlockTransaction( return nil, storageErrs.ErrCannotAccessPrunedData } - namespace, key := getTransactionHashKey(transactionIdentifier) + namespace, key := getTransactionKey(blockIdentifier, transactionIdentifier) txExists, tx, err := txn.Get(ctx, key) if err != nil { return nil, fmt.Errorf("%w: %v", storageErrs.ErrTransactionDBQueryFailed, err) @@ -1030,22 +1033,12 @@ func (b *BlockStorage) findBlockTransaction( ) } - var blocks map[string]*blockTransaction - if err := b.db.Encoder().Decode(namespace, tx, &blocks, true); err != nil { + var bt blockTransaction + if err := b.db.Encoder().Decode(namespace, tx, &bt, true); err != nil { return nil, fmt.Errorf("%w: unable to decode block data for transaction", err) } - val, ok := blocks[blockIdentifier.Hash] - if !ok { - return nil, fmt.Errorf( - "%w: did not find transaction %s in block %s", - storageErrs.ErrTransactionDoesNotExistInBlock, - transactionIdentifier.Hash, - blockIdentifier.Hash, - ) - } - - return val.Transaction, nil + return bt.Transaction, nil } // GetBlockTransaction retrieves a transaction belonging to a certain