11package activation
22
33import (
4+ "bytes"
45 "context"
56 "errors"
67 "fmt"
78 "math"
9+ "sort"
810 "time"
911
1012 "github.com/spacemeshos/post/shared"
@@ -146,8 +148,6 @@ func (h *HandlerV2) processATX(
146148}
147149
148150// Syntactically validate an ATX.
149- // TODOs:
150- // 2. support merged ATXs.
151151func (h * HandlerV2 ) syntacticallyValidate (ctx context.Context , atx * wire.ActivationTxV2 ) error {
152152 if ! h .edVerifier .Verify (signing .ATX , atx .SmesherID , atx .SignedBytes (), atx .Signature ) {
153153 return fmt .Errorf ("invalid atx signature: %w" , errMalformedData )
@@ -219,8 +219,9 @@ func (h *HandlerV2) syntacticallyValidate(ctx context.Context, atx *wire.Activat
219219 if len (atx .Marriages ) != 0 {
220220 return errors .New ("merged atx cannot have marriages" )
221221 }
222- // TODO: support merged ATXs
223- return errors .New ("atx merge is not supported" )
222+ if err := h .verifyIncludedIDsUniqueness (atx ); err != nil {
223+ return err
224+ }
224225 default :
225226 // Solo chained (non-initial) ATX
226227 if len (atx .PreviousATXs ) != 1 {
@@ -353,13 +354,28 @@ func (h *HandlerV2) validatePreviousAtx(id types.NodeID, post *wire.SubPostV2, p
353354 }
354355 return min (prev .NumUnits , post .NumUnits ), nil
355356 case * wire.ActivationTxV2 :
356- // TODO: support previous merged-ATX
357-
358- // previous is solo ATX
359- if prev .SmesherID == id {
360- return min (prev .NiPosts [0 ].Posts [0 ].NumUnits , post .NumUnits ), nil
357+ if prev .MarriageATX != nil {
358+ // Previous is a merged ATX
359+ // need to find out if the given ID was present in the previous ATX
360+ _ , idx , err := identities .MarriageInfo (h .cdb , id )
361+ if err != nil {
362+ return 0 , fmt .Errorf ("fetching marriage info for ID %s: %w" , id , err )
363+ }
364+ for _ , nipost := range prev .NiPosts {
365+ for _ , post := range nipost .Posts {
366+ if post .MarriageIndex == uint32 (idx ) {
367+ return min (post .NumUnits , post .NumUnits ), nil
368+ }
369+ }
370+ }
371+ } else {
372+ // Previous is a solo ATX
373+ if prev .SmesherID == id {
374+ return min (prev .NiPosts [0 ].Posts [0 ].NumUnits , post .NumUnits ), nil
375+ }
361376 }
362- return 0 , fmt .Errorf ("previous solo ATX V2 has different owner: %s (expected %s)" , prev .SmesherID , id )
377+
378+ return 0 , fmt .Errorf ("previous ATX V2 doesn't contain %s" , id )
363379 }
364380 return 0 , fmt .Errorf ("unexpected previous ATX type: %T" , prev )
365381}
@@ -398,11 +414,56 @@ func (h *HandlerV2) validatePositioningAtx(publish types.EpochID, golden, positi
398414 return posAtx .TickHeight (), nil
399415}
400416
417+ // Validate marriage ATX and return the full equivocation set.
418+ func (h * HandlerV2 ) validateMarriages (atx * wire.ActivationTxV2 ) ([]types.NodeID , error ) {
419+ if atx .MarriageATX == nil {
420+ return []types.NodeID {atx .SmesherID }, nil
421+ }
422+ marriageAtxID , _ , err := identities .MarriageInfo (h .cdb , atx .SmesherID )
423+ switch {
424+ case errors .Is (err , sql .ErrNotFound ) || marriageAtxID == nil :
425+ return nil , errors .New ("smesher is not married" )
426+ case err != nil :
427+ return nil , fmt .Errorf ("fetching smesher's marriage atx ID: %w" , err )
428+ }
429+
430+ if * atx .MarriageATX != * marriageAtxID {
431+ return nil , fmt .Errorf ("smesher's marriage ATX ID mismatch: %s != %s" , * atx .MarriageATX , * marriageAtxID )
432+ }
433+
434+ marriageAtx , err := atxs .Get (h .cdb , * atx .MarriageATX )
435+ if err != nil {
436+ return nil , fmt .Errorf ("fetching marriage atx: %w" , err )
437+ }
438+ if ! (marriageAtx .PublishEpoch <= atx .PublishEpoch - 2 ) {
439+ return nil , fmt .Errorf (
440+ "marriage atx must be published at least 2 epochs before %v (is %v)" ,
441+ atx .PublishEpoch ,
442+ marriageAtx .PublishEpoch ,
443+ )
444+ }
445+
446+ return identities .EquivocationSetByMarriageATX (h .cdb , * atx .MarriageATX )
447+ }
448+
401449type atxParts struct {
402450 leaves uint64
403451 effectiveUnits uint32
404452}
405453
454+ func (h * HandlerV2 ) verifyIncludedIDsUniqueness (atx * wire.ActivationTxV2 ) error {
455+ seen := make (map [uint32 ]struct {})
456+ for _ , niposts := range atx .NiPosts {
457+ for _ , post := range niposts .Posts {
458+ if _ , ok := seen [post .MarriageIndex ]; ok {
459+ return fmt .Errorf ("ID present twice (duplicated marriage index): %d" , post .MarriageIndex )
460+ }
461+ seen [post .MarriageIndex ] = struct {}{}
462+ }
463+ }
464+ return nil
465+ }
466+
406467// Syntactically validate the ATX with its dependencies.
407468func (h * HandlerV2 ) syntacticallyValidateDeps (
408469 ctx context.Context ,
@@ -427,33 +488,42 @@ func (h *HandlerV2) syntacticallyValidateDeps(
427488 previousAtxs [i ] = prevAtx
428489 }
429490
430- // validate all niposts
431- // TODO: support merged ATXs
432- // For a merged ATX we need to fetch the equivocation this smesher is part of.
433- equivocationSet := []types.NodeID {atx .SmesherID }
491+ equivocationSet , err := h .validateMarriages (atx )
492+ if err != nil {
493+ return nil , nil , fmt .Errorf ("validating marriages: %w" , err )
494+ }
495+
496+ // validate previous ATXs
434497 var totalEffectiveNumUnits uint32
435- var minLeaves uint64 = math .MaxUint64
436- var smesherCommitment * types.ATXID
437498 for _ , niposts := range atx .NiPosts {
438- // verify PoET memberships in a single go
439- var poetChallenges [][]byte
440-
441499 for _ , post := range niposts .Posts {
442500 if post .MarriageIndex >= uint32 (len (equivocationSet )) {
443501 err := fmt .Errorf ("marriage index out of bounds: %d > %d" , post .MarriageIndex , len (equivocationSet )- 1 )
444502 return nil , nil , err
445503 }
504+
446505 id := equivocationSet [post .MarriageIndex ]
447506 effectiveNumUnits := post .NumUnits
448507 if atx .Initial == nil {
449508 var err error
450509 effectiveNumUnits , err = h .validatePreviousAtx (id , & post , previousAtxs )
451510 if err != nil {
452- return nil , nil , fmt .Errorf ("validating previous atx for ID %s : %w" , id , err )
511+ return nil , nil , fmt .Errorf ("validating previous atx: %w" , err )
453512 }
454513 }
455514 totalEffectiveNumUnits += effectiveNumUnits
515+ }
516+ }
456517
518+ // validate all niposts
519+ var minLeaves uint64 = math .MaxUint64
520+ var smesherCommitment * types.ATXID
521+ for _ , niposts := range atx .NiPosts {
522+ // verify PoET memberships in a single go
523+ var poetChallenges [][]byte
524+
525+ for _ , post := range niposts .Posts {
526+ id := equivocationSet [post .MarriageIndex ]
457527 var commitment types.ATXID
458528 if atx .Initial != nil {
459529 commitment = atx .Initial .CommitmentATX
@@ -463,7 +533,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
463533 if err != nil {
464534 return nil , nil , fmt .Errorf ("commitment atx not found for ID %s: %w" , id , err )
465535 }
466- if smesherCommitment == nil {
536+ if id == atx . SmesherID {
467537 smesherCommitment = & commitment
468538 }
469539 }
@@ -506,6 +576,9 @@ func (h *HandlerV2) syntacticallyValidateDeps(
506576 Nodes : niposts .Membership .Nodes ,
507577 LeafIndices : niposts .Membership .LeafIndices ,
508578 }
579+ sort .Slice (poetChallenges , func (i , j int ) bool {
580+ return bytes .Compare (poetChallenges [i ], poetChallenges [j ]) < 0
581+ })
509582 leaves , err := h .nipostValidator .PoetMembership (ctx , & membership , niposts .Challenge , poetChallenges )
510583 if err != nil {
511584 return nil , nil , fmt .Errorf ("invalid poet membership: %w" , err )
@@ -519,6 +592,9 @@ func (h *HandlerV2) syntacticallyValidateDeps(
519592 }
520593
521594 if atx .Initial == nil {
595+ if smesherCommitment == nil {
596+ return nil , nil , errors .New ("ATX signer not present in merged ATX" )
597+ }
522598 err := h .nipostValidator .VRFNonceV2 (atx .SmesherID , * smesherCommitment , atx .VRFNonce , atx .TotalNumUnits ())
523599 if err != nil {
524600 return nil , nil , fmt .Errorf ("validating VRF nonce: %w" , err )
@@ -563,6 +639,10 @@ func (h *HandlerV2) checkDoubleMarry(
563639 tx * sql.Tx ,
564640 watx * wire.ActivationTxV2 ,
565641) (* mwire.MalfeasanceProof , error ) {
642+ if len (watx .Marriages ) == 0 {
643+ // not trying to marry
644+ return nil , nil
645+ }
566646 checkMarried := func (tx * sql.Tx , id types.NodeID ) (* mwire.MalfeasanceProof , error ) {
567647 married , err := identities .Married (tx , id )
568648 if err != nil {
@@ -611,12 +691,12 @@ func (h *HandlerV2) storeAtx(
611691 }
612692
613693 if len (watx .Marriages ) != 0 {
614- for _ , m := range watx .Marriages {
615- if err := identities .SetMarriage (tx , m .ID , atx .ID ()); err != nil {
694+ for i , m := range watx .Marriages {
695+ if err := identities .SetMarriage (tx , m .ID , atx .ID (), i + 1 ); err != nil {
616696 return err
617697 }
618698 }
619- if err := identities .SetMarriage (tx , atx .SmesherID , atx .ID ()); err != nil {
699+ if err := identities .SetMarriage (tx , atx .SmesherID , atx .ID (), 0 ); err != nil {
620700 return err
621701 }
622702 if ! malicious && proof == nil {
0 commit comments