Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 103 additions & 110 deletions storage/modules/block_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"log"
"runtime"
"strconv"
"strings"

"github.com/neilotoole/errgroup"

Expand Down Expand Up @@ -77,9 +78,29 @@ func getBlockIndexKey(index int64) []byte {
return []byte(fmt.Sprintf("%s/%d", blockIndexNamespace, index))
}

func getTransactionHashKey(transactionIdentifier *types.TransactionIdentifier) (string, []byte) {
func getTransactionKey(
blockIdentifier *types.BlockIdentifier,
transactionIdentifier *types.TransactionIdentifier,
) (string, []byte) {
return transactionNamespace, []byte(
fmt.Sprintf("%s/%s", transactionNamespace, transactionIdentifier.Hash),
fmt.Sprintf(
"%s/%s/%s",
transactionNamespace,
transactionIdentifier.Hash,
blockIdentifier.Hash,
),
)
}

func getTransactionPrefix(
transactionIdentifier *types.TransactionIdentifier,
) []byte {
return []byte(
fmt.Sprintf(
"%s/%s/",
transactionNamespace,
transactionIdentifier.Hash,
),
)
}

Expand Down Expand Up @@ -230,10 +251,24 @@ func (b *BlockStorage) pruneBlock(
if err == nil {
blockIdentifier := blockResponse.Block.BlockIdentifier

for _, tx := range blockResponse.OtherTransactions {
if err := b.pruneTransaction(ctx, dbTx, blockIdentifier, tx); err != nil {
return -1, fmt.Errorf("%w: %v", storageErrs.ErrCannotPruneTransaction, err)
}
// Remove all transaction hashes
g, gctx := errgroup.WithContextN(ctx, b.numCPU, b.numCPU)
for i := range blockResponse.OtherTransactions {
// We need to set variable before calling goroutine
// to avoid getting an updated pointer as loop iteration
// continues.
tx := blockResponse.OtherTransactions[i]
g.Go(func() error {
if err := b.pruneTransaction(gctx, dbTx, blockIdentifier, tx); err != nil {
return fmt.Errorf("%w: %v", storageErrs.ErrCannotPruneTransaction, err)
}

return nil
})
}

if err := g.Wait(); err != nil {
return -1, err
}

_, blockKey := getBlockHashKey(blockIdentifier.Hash)
Expand Down Expand Up @@ -846,54 +881,24 @@ func (b *BlockStorage) CreateBlockCache(ctx context.Context, blocks int) []*type
return cache
}

func (b *BlockStorage) updateTransaction(
ctx context.Context,
dbTx database.Transaction,
hashKey []byte,
namespace string,
blocks map[string]*blockTransaction,
) error {
encodedResult, err := b.db.Encoder().Encode(namespace, blocks)
if err != nil {
return fmt.Errorf("%w: %v", storageErrs.ErrTransactionDataEncodeFailed, err)
}

if err := dbTx.Set(ctx, hashKey, encodedResult, true); err != nil {
return err
}

return nil
}

func (b *BlockStorage) storeTransaction(
ctx context.Context,
transaction database.Transaction,
blockIdentifier *types.BlockIdentifier,
tx *types.Transaction,
) error {
namespace, hashKey := getTransactionHashKey(tx.TransactionIdentifier)
exists, val, err := transaction.Get(ctx, hashKey)
if err != nil {
return err
}

var blocks map[string]*blockTransaction
if !exists {
blocks = make(map[string]*blockTransaction)
} else {
err := b.db.Encoder().Decode(namespace, val, &blocks, true)
if err != nil {
return fmt.Errorf("%w: could not decode transaction hash contents", err)
}
}
// We check for duplicates before storing transaction,
// so this must be a new key.
blocks[blockIdentifier.Hash] = &blockTransaction{
namespace, hashKey := getTransactionKey(blockIdentifier, tx.TransactionIdentifier)
bt := &blockTransaction{
Transaction: tx,
BlockIndex: blockIdentifier.Index,
}

return b.updateTransaction(ctx, transaction, hashKey, namespace, blocks)
encodedResult, err := b.db.Encoder().Encode(namespace, bt)
if err != nil {
return fmt.Errorf("%w: %v", storageErrs.ErrTransactionDataEncodeFailed, err)
}

return storeUniqueKey(ctx, transaction, hashKey, encodedResult, true)
}

func (b *BlockStorage) pruneTransaction(
Expand All @@ -902,25 +907,17 @@ func (b *BlockStorage) pruneTransaction(
blockIdentifier *types.BlockIdentifier,
txIdentifier *types.TransactionIdentifier,
) error {
namespace, hashKey := getTransactionHashKey(txIdentifier)
exists, val, err := transaction.Get(ctx, hashKey)
if err != nil {
return err
}
if !exists {
return storageErrs.ErrTransactionNotFound
}

var blocks map[string]*blockTransaction
if err := b.db.Encoder().Decode(namespace, val, &blocks, true); err != nil {
return fmt.Errorf("%w: could not decode transaction hash contents", err)
namespace, hashKey := getTransactionKey(blockIdentifier, txIdentifier)
bt := &blockTransaction{
BlockIndex: blockIdentifier.Index,
}

blocks[blockIdentifier.Hash] = &blockTransaction{
BlockIndex: blockIdentifier.Index,
encodedResult, err := b.db.Encoder().Encode(namespace, bt)
if err != nil {
return fmt.Errorf("%w: %v", storageErrs.ErrTransactionDataEncodeFailed, err)
}

return b.updateTransaction(ctx, transaction, hashKey, namespace, blocks)
return transaction.Set(ctx, hashKey, encodedResult, true)
}

func (b *BlockStorage) removeTransaction(
Expand All @@ -929,36 +926,49 @@ func (b *BlockStorage) removeTransaction(
blockIdentifier *types.BlockIdentifier,
transactionIdentifier *types.TransactionIdentifier,
) error {
namespace, hashKey := getTransactionHashKey(transactionIdentifier)
exists, val, err := transaction.Get(ctx, hashKey)
if err != nil {
return err
}
_, hashKey := getTransactionKey(blockIdentifier, transactionIdentifier)

if !exists {
return fmt.Errorf(
"%w %s",
storageErrs.ErrTransactionDeleteFailed,
transactionIdentifier.Hash,
)
}

var blocks map[string]*blockTransaction
if err := b.db.Encoder().Decode(namespace, val, &blocks, true); err != nil {
return fmt.Errorf("%w: could not decode transaction hash contents", err)
}

if _, exists := blocks[blockIdentifier.Hash]; !exists {
return fmt.Errorf("%w %s", storageErrs.ErrTransactionHashNotFound, blockIdentifier.Hash)
}
return transaction.Delete(ctx, hashKey)
}

delete(blocks, blockIdentifier.Hash)
func (b *BlockStorage) getAllTransactionsByIdentifier(
ctx context.Context,
transactionIdentifier *types.TransactionIdentifier,
txn database.Transaction,
) ([]*types.BlockTransaction, error) {
blockTransactions := []*types.BlockTransaction{}
_, err := txn.Scan(
ctx,
getTransactionPrefix(transactionIdentifier),
getTransactionPrefix(transactionIdentifier),
func(k []byte, v []byte) error {
// Decode blockTransaction
var bt blockTransaction
if err := b.db.Encoder().Decode(transactionNamespace, v, &bt, false); err != nil {
return fmt.Errorf("%w: unable to decode block data for transaction", err)
}

if len(blocks) == 0 {
return transaction.Delete(ctx, hashKey)
// Extract hash from key
splitKey := strings.Split(string(k), "/")
blockHash := splitKey[len(splitKey)-1]

blockTransactions = append(blockTransactions, &types.BlockTransaction{
BlockIdentifier: &types.BlockIdentifier{
Index: bt.BlockIndex,
Hash: blockHash,
},
Transaction: bt.Transaction,
})
return nil
},
false,
false,
)
if err != nil {
return nil, err
}

return b.updateTransaction(ctx, transaction, hashKey, namespace, blocks)
return blockTransactions, nil
}

// FindTransaction returns the most recent *types.BlockIdentifier containing the
Expand All @@ -968,27 +978,20 @@ func (b *BlockStorage) FindTransaction(
transactionIdentifier *types.TransactionIdentifier,
txn database.Transaction,
) (*types.BlockIdentifier, *types.Transaction, error) {
namespace, key := getTransactionHashKey(transactionIdentifier)
txExists, tx, err := txn.Get(ctx, key)
blockTransactions, err := b.getAllTransactionsByIdentifier(ctx, transactionIdentifier, txn)
if err != nil {
return nil, nil, fmt.Errorf("%w: %v", storageErrs.ErrTransactionDBQueryFailed, err)
}

if !txExists {
if len(blockTransactions) == 0 {
return nil, nil, nil
}

var blocks map[string]*blockTransaction
if err := b.db.Encoder().Decode(namespace, tx, &blocks, true); err != nil {
return nil, nil, fmt.Errorf("%w: unable to decode block data for transaction", err)
}

var newestBlock *types.BlockIdentifier
var newestTransaction *types.Transaction
for hash, blockTransaction := range blocks {
b := &types.BlockIdentifier{Hash: hash, Index: blockTransaction.BlockIndex}
if newestBlock == nil || blockTransaction.BlockIndex > newestBlock.Index {
newestBlock = b
for _, blockTransaction := range blockTransactions {
if newestBlock == nil || blockTransaction.BlockIdentifier.Index > newestBlock.Index {
newestBlock = blockTransaction.BlockIdentifier
newestTransaction = blockTransaction.Transaction
}
}
Expand Down Expand Up @@ -1016,7 +1019,7 @@ func (b *BlockStorage) findBlockTransaction(
return nil, storageErrs.ErrCannotAccessPrunedData
}

namespace, key := getTransactionHashKey(transactionIdentifier)
namespace, key := getTransactionKey(blockIdentifier, transactionIdentifier)
txExists, tx, err := txn.Get(ctx, key)
if err != nil {
return nil, fmt.Errorf("%w: %v", storageErrs.ErrTransactionDBQueryFailed, err)
Expand All @@ -1030,22 +1033,12 @@ func (b *BlockStorage) findBlockTransaction(
)
}

var blocks map[string]*blockTransaction
if err := b.db.Encoder().Decode(namespace, tx, &blocks, true); err != nil {
var bt blockTransaction
if err := b.db.Encoder().Decode(namespace, tx, &bt, true); err != nil {
return nil, fmt.Errorf("%w: unable to decode block data for transaction", err)
}

val, ok := blocks[blockIdentifier.Hash]
if !ok {
return nil, fmt.Errorf(
"%w: did not find transaction %s in block %s",
storageErrs.ErrTransactionDoesNotExistInBlock,
transactionIdentifier.Hash,
blockIdentifier.Hash,
)
}

return val.Transaction, nil
return bt.Transaction, nil
}

// GetBlockTransaction retrieves a transaction belonging to a certain
Expand Down