@@ -2,11 +2,13 @@ package treestore
22
33import (
44 "bytes"
5+ "context"
56 "crypto/sha256"
67 "encoding/hex"
78 "fmt"
89 "log"
910
11+ "github.com/jackc/pgx/v5"
1012 pocketLogger "github.com/pokt-network/pocket/logger"
1113 "github.com/pokt-network/pocket/persistence/indexer"
1214 "github.com/pokt-network/pocket/persistence/kvstore"
@@ -138,7 +140,7 @@ func newMemStateTrees() (*StateTrees, error) {
138140 return stateTrees , nil
139141}
140142
141- func (s * StateTrees ) UpdateMerkleTrees (txi indexer.TxIndexer ) (string , error ) {
143+ func (s * StateTrees ) UpdateMerkleTrees (ctx context. Context , tx pgx. Tx , txi indexer.TxIndexer ) (string , error ) {
142144 // Update all the merkle trees
143145 for treeType := merkleTree (0 ); treeType < numMerkleTrees ; treeType ++ {
144146 switch treeType {
@@ -148,7 +150,7 @@ func (s *StateTrees) UpdateMerkleTrees(txi indexer.TxIndexer) (string, error) {
148150 if ! ok {
149151 return "" , fmt .Errorf ("no actor type found for merkle tree: %v" , treeType )
150152 }
151- if err := s .updateActorsTree (actorType ); err != nil {
153+ if err := s .updateActorsTree (ctx , tx , actorType ); err != nil {
152154 return "" , fmt .Errorf ("failed to update actors tree for treeType: %v, actorType: %v - %w" , treeType , actorType , err )
153155 }
154156
@@ -172,7 +174,7 @@ func (s *StateTrees) UpdateMerkleTrees(txi indexer.TxIndexer) (string, error) {
172174 return "" , fmt .Errorf ("failed to update params tree - %w" , err )
173175 }
174176 case flagsMerkleTree :
175- if err := s .updateFlagsTree (); err != nil {
177+ if err := s .updateFlagsTree ([] * coreTypes. Flag { /* TODO_IN_THIS_COMMIT */ } ); err != nil {
176178 return "" , fmt .Errorf ("failed to update flags tree - %w" , err )
177179 }
178180
@@ -202,9 +204,9 @@ func (s *StateTrees) getStateHash() string {
202204
203205// Actor Tree Helpers
204206
205- func (s * StateTrees ) updateActorsTree (actorType coreTypes.ActorType ) error {
207+ func (s * StateTrees ) updateActorsTree (ctx context. Context , tx pgx. Tx , actorType coreTypes.ActorType ) error {
206208
207- actors , err := s .getActorsUpdatedAtHeight (actorType , s .Height ) // TODO IN THIS COMMIT deal with height for state trees component
209+ actors , err := s .getActorsUpdatedAtHeight (ctx , tx , actorType , s .Height ) // TODO IN THIS COMMIT deal with height for state trees component
208210 if err != nil {
209211 return err
210212 }
@@ -232,13 +234,13 @@ func (s *StateTrees) updateActorsTree(actorType coreTypes.ActorType) error {
232234 return nil
233235}
234236
235- func (s * StateTrees ) getActorsUpdatedAtHeight (actorType coreTypes.ActorType , height int64 ) (actors []* coreTypes.Actor , err error ) {
237+ func (s * StateTrees ) getActorsUpdatedAtHeight (ctx context. Context , tx pgx. Tx , actorType coreTypes.ActorType , height int64 ) (actors []* coreTypes.Actor , err error ) {
236238 actorSchema , ok := actorTypeToSchemaName [actorType ]
237239 if ! ok {
238240 return nil , fmt .Errorf ("no schema found for actor type: %s" , actorType )
239241 }
240242
241- schemaActors , err := s .GetActorsUpdated (actorSchema , height )
243+ schemaActors , err := s .GetActorsUpdated (ctx , tx , actorSchema , height )
242244 if err != nil {
243245 return nil , err
244246 }
@@ -265,11 +267,6 @@ func (s *StateTrees) getActorsUpdatedAtHeight(actorType coreTypes.ActorType, hei
265267
266268// TODO_IN_THIS_COMMIT figure out how to pass accounts to this function
267269func (s * StateTrees ) updateAccountTrees (accounts []* coreTypes.Account ) error {
268- // accounts, err := s.GetAccountsUpdated(s.Height)
269- // if err != nil {
270- // return err
271- // }
272-
273270 for _ , account := range accounts {
274271 bzAddr , err := hex .DecodeString (account .GetAddress ())
275272 if err != nil {
@@ -290,26 +287,22 @@ func (s *StateTrees) updateAccountTrees(accounts []*coreTypes.Account) error {
290287}
291288
292289func (s * StateTrees ) updatePoolTrees (pools []* coreTypes.Account ) error {
293- // pools, err := s.GetPoolsUpdated(s.Height)
294- // if err != nil {
295- // return err
296- // }
297-
298290 for _ , pool := range pools {
299291 log .Printf ("pool %+v" , pool )
300- // bzAddr, err := hex.DecodeString(pool.GetAddress())
301- // if err != nil {
302- // return err
303- // }
304-
305- // accBz, err := s.merkleTrees[accountMerkleTree].
306- // if err != nil {
307- // return err
308- // }
309-
310- // if _, err := s.merkleTrees[poolMerkleTree].Update(bzAddr, accBz); err != nil {
311- // return err
312- // }
292+ bzAddr , err := hex .DecodeString (pool .GetAddress ())
293+ if err != nil {
294+ return err
295+ }
296+
297+ // TODO verify this is the correct logic here - these trees are complicated AF
298+ accBz , err := s .merkleTrees [accountMerkleTree ].Get (bzAddr )
299+ if err != nil {
300+ return err
301+ }
302+
303+ if _ , err := s .merkleTrees [poolMerkleTree ].Update (bzAddr , accBz ); err != nil {
304+ return err
305+ }
313306 }
314307
315308 return nil
@@ -349,25 +342,19 @@ func (s *StateTrees) updateParamsTree(params []*coreTypes.Param) error {
349342 return nil
350343}
351344
352- func (s * StateTrees ) updateFlagsTree () error {
353- // flags, err := s.getFlagsUpdated(s.Height)
354- // if err != nil {
355- // return err
356- // }
357-
358- // for _, flag := range flags {
359- // flagBz, err := codec.GetCodec().Marshal(flag)
360- // flagKey := crypto.SHA3Hash([]byte(flag.Name))
361- // if err != nil {
362- // return err
363- // }
364- // if _, err := s.merkleTrees[flagsMerkleTree].Update(flagKey, flagBz); err != nil {
365- // return err
366- // }
367- // }
368-
369- // return nil
370- return fmt .Errorf ("not impl" )
345+ func (s * StateTrees ) updateFlagsTree (flags []* coreTypes.Flag ) error {
346+ for _ , flag := range flags {
347+ flagBz , err := codec .GetCodec ().Marshal (flag )
348+ flagKey := crypto .SHA3Hash ([]byte (flag .Name ))
349+ if err != nil {
350+ return err
351+ }
352+ if _ , err := s .merkleTrees [flagsMerkleTree ].Update (flagKey , flagBz ); err != nil {
353+ return err
354+ }
355+ }
356+
357+ return nil
371358}
372359
373360func (s * StateTrees ) ClearAllTreeState () error {
@@ -389,37 +376,52 @@ func (s *StateTrees) ClearAllTreeState() error {
389376 return nil
390377}
391378
392- func (s * StateTrees ) GetActorsUpdated (actorSchema types.ProtocolActorSchema , height int64 ) (actors []* coreTypes.Actor , err error ) {
393- // ctx, tx := s.getCtxAndTx()
394-
395- // rows, err := tx.Query(ctx, actorSchema.GetUpdatedAtHeightQuery(height))
396- // if err != nil {
397- // return nil, err
398- // }
399- // defer rows.Close()
400-
401- // addrs := make([][]byte, 0)
402- // for rows.Next() {
403- // var addr string
404- // if err := rows.Scan(&addr); err != nil {
405- // return nil, err
406- // }
407- // addrBz, err := hex.DecodeString(addr)
408- // if err != nil {
409- // return nil, err
410- // }
411- // addrs = append(addrs, addrBz)
412- // }
413- // rows.Close()
414-
415- // actors = make([]*coreTypes.Actor, len(addrs))
416- // for i, addr := range addrs {
417- // actor, err := s.getActor(actorSchema, addr, height)
418- // if err != nil {
419- // return nil, err
420- // }
421- // actors[i] = actor
422- // }
379+ func (s * StateTrees ) GetActorsUpdated (
380+ ctx context.Context ,
381+ tx pgx.Tx ,
382+ actorSchema types.ProtocolActorSchema ,
383+ height int64 ,
384+ ) (actors []* coreTypes.Actor , err error ) {
385+ rows , err := tx .Query (ctx , actorSchema .GetUpdatedAtHeightQuery (height ))
386+ if err != nil {
387+ return nil , err
388+ }
389+ defer rows .Close ()
390+
391+ addrs := make ([][]byte , 0 )
392+ for rows .Next () {
393+ var addr string
394+ if err := rows .Scan (& addr ); err != nil {
395+ return nil , err
396+ }
397+ addrBz , err := hex .DecodeString (addr )
398+ if err != nil {
399+ return nil , err
400+ }
401+ addrs = append (addrs , addrBz )
402+ }
403+ rows .Close ()
404+
405+ actors = make ([]* coreTypes.Actor , len (addrs ))
406+ for i , addr := range addrs {
407+ actor , err := s .getActor (ctx , tx , actorSchema , addr , height )
408+ if err != nil {
409+ return nil , err
410+ }
411+ actors [i ] = actor
412+ }
413+
414+ return actors , nil
415+ }
423416
417+ // getActor returns an actor from an address from the database
418+ // during an open transaction.
419+ func (s * StateTrees ) getActor (
420+ ctx context.Context ,
421+ tx pgx.Tx ,
422+ actorSchema types.ProtocolActorSchema ,
423+ addr []byte ,
424+ height int64 ,
425+ ) (actors * coreTypes.Actor , err error ) {
424426 return nil , fmt .Errorf ("not impl" )
425427}
0 commit comments