Skip to content

Commit 084899f

Browse files
Protect against nil block
1 parent 39bc21f commit 084899f

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

storage/balance_storage.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,19 +461,30 @@ func (b *BalanceStorage) OrphanBalance(
461461
// retrieved once (like reconciliation).
462462
func (b *BalanceStorage) PruneBalances(
463463
ctx context.Context,
464-
dbTransaction DatabaseTransaction,
465464
account *types.AccountIdentifier,
466465
currency *types.Currency,
467466
index int64,
468467
) error {
469-
return b.removeHistoricalBalances(
468+
dbTx := b.db.NewDatabaseTransaction(ctx, true)
469+
defer dbTx.Discard(ctx)
470+
471+
err := b.removeHistoricalBalances(
470472
ctx,
471-
dbTransaction,
473+
dbTx,
472474
account,
473475
currency,
474476
index,
475477
false,
476478
)
479+
if err != nil {
480+
return fmt.Errorf("%w: unable to remove historical balances", err)
481+
}
482+
483+
if err := dbTx.Commit(ctx); err != nil {
484+
return fmt.Errorf("%w: unable to commit historical balance removal", err)
485+
}
486+
487+
return nil
477488
}
478489

479490
// UpdateBalance updates a types.AccountIdentifer
@@ -639,7 +650,9 @@ func (b *BalanceStorage) GetBalanceTransactional(
639650
currency *types.Currency,
640651
block *types.BlockIdentifier,
641652
) (*types.Amount, error) {
642-
// TODO: if block > head block, should return an error
653+
if block == nil {
654+
return nil, errors.New("block cannot be empty")
655+
}
643656

644657
key := GetAccountKey(account, currency)
645658
exists, acct, err := dbTx.Get(ctx, key)

storage/errors.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,15 @@ var (
324324
// to retrieve a pruned balance.
325325
ErrBalancePruned = errors.New("balance pruned")
326326

327+
// ErrBlockNil is returned when the block to lookup
328+
// a balance at is nil.
329+
ErrBlockNil = errors.New("block nil")
330+
327331
BalanceStorageErrs = []error{
328332
ErrNegativeBalance,
329333
ErrInvalidLiveBalance,
330334
ErrBalancePruned,
335+
ErrBlockNil,
331336
}
332337

333338
///////////////////

0 commit comments

Comments
 (0)