Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ _testmain.go
*.exe
*.test
*.prof
mrseq
101 changes: 101 additions & 0 deletions internal/seqno/seqno.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package seqno

import (
"github.com/go-kit/log"

"github.com/Azure/applicationhealth-extension-linux/pkg/logging"
"github.com/Azure/azure-extension-platform/pkg/extensionerrors"
"github.com/Azure/azure-extension-platform/pkg/seqno"
)

type SequenceNumberManager interface {
// GetCurrentSequenceNumber returns the current sequence number the extension is using
GetCurrentSequenceNumber(name, version string) (int, error)

// GetSequenceNumber retrieves the sequence number from the MRSEQ file
GetSequenceNumber(name, version string) (int, error)

// SetSequenceNumber sets the sequence number to the MRSEQ file.
SetSequenceNumber(name, version string, seqNo int) error

// FindSeqNum returns the requested the sequence number from either the environment variable or
// the most recently used file under the config folder.
// Note that this is different than just choosing the highest number, which may be incorrect
FindSeqNum(configFolder string) (int, error)
}

type SeqNumManager struct {
GetCurrentSequenceNumberFunc func(lg log.Logger, name, version string) (int, error)
GetSequenceNumberFunc func(name, version string) (int, error)
SequenceNumberSetterFunc func(name, version string, seqNo int) error
FindSeqNumFunc func(configFolder string) (int, error)
}

func (s *SeqNumManager) GetCurrentSequenceNumber(name, version string) (int, error) {
lg := logging.NewNopLogger()
return s.GetCurrentSequenceNumberFunc(lg, name, version)
}

func (s *SeqNumManager) GetSequenceNumber(name string, version string) (int, error) {
return s.GetSequenceNumberFunc(name, version)
}

func (s *SeqNumManager) SetSequenceNumber(name, version string, seqNo int) error {
return s.SequenceNumberSetterFunc(name, version, seqNo)
}

func (s *SeqNumManager) FindSeqNum(configFolder string) (int, error) {
return s.FindSeqNumFunc(configFolder)
}

func New() SequenceNumberManager {
return &SeqNumManager{
GetCurrentSequenceNumberFunc: GetCurrentSequenceNumberFunc(GetSequenceNumberFunc),
GetSequenceNumberFunc: GetSequenceNumberFunc,
SequenceNumberSetterFunc: SetSequenceNumber,
FindSeqNumFunc: FindSeqNum,
}
}

func GetSequenceNumberFunc(name, version string) (int, error) {
retriever := &seqno.ProdSequenceNumberRetriever{}
seqNum, err := retriever.GetSequenceNumber(name, version)
return int(seqNum), err
}

// SetSequenceNumber sets the sequence number for the given extension name and version.
// It takes the extension name, extension version, and sequence number as parameters.
// The sequence number is an integer that represents the order in which the extension was installed.
// It returns an error if there was a problem setting the sequence number.
func SetSequenceNumber(extName, extVersion string, seqNo int) error {
return seqno.SetSequenceNumber(extName, extVersion, uint(seqNo))
}

// FindSeqNum finds the sequence number for the given config folder.
// It returns the sequence number as an integer and any error encountered.
func FindSeqNum(configFolder string) (int, error) {
seqNum, err := seqno.FindSeqNum(logging.NewNopLogger(), configFolder)
if err != nil {
return 0, err
}
return int(seqNum), nil
}

func GetCurrentSequenceNumberFunc(getSequenceNumberFunc func(name, version string) (int, error)) func(lg log.Logger, name, version string) (int, error) {
return func(lg log.Logger, name, version string) (int, error) {
return getCurrentSequenceNumber(lg, getSequenceNumberFunc, name, version)
}
}

// GetCurrentSequenceNumber returns the current sequence number the extension is using
func getCurrentSequenceNumber(lg log.Logger, getSequenceNumberFunc func(name, version string) (int, error), name, version string) (int, error) {
sequenceNumber, err := getSequenceNumberFunc(name, version)
if err == extensionerrors.ErrNotFound || err == extensionerrors.ErrNoMrseqFile {
// If we can't find the sequence number, then it's possible that the extension
// hasn't been installed yet. Go back to 0.
lg.Log("event", "Couldn't find current sequence number, likely first execution of the extension, returning sequence number 0")
return 0, nil
}

return int(sequenceNumber), err
}
24 changes: 23 additions & 1 deletion main/cmds.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const (

var (
cmdInstall = cmd{install, "Install", false, nil, 52}
cmdEnable = cmd{enable, "Enable", true, nil, 3}
cmdEnable = cmd{enable, "Enable", true, enablePre, 3}
cmdUninstall = cmd{uninstall, "Uninstall", false, nil, 3}

cmds = map[string]cmd{
Expand Down Expand Up @@ -77,6 +77,28 @@ var (
errTerminated = errors.New("Application health process terminated")
)

func enablePre(lg log.Logger, seqNum int) error {
// exit if this sequence number (a snapshot of the configuration) is already
// processed. if not, save this sequence number before proceeding.

mrSeqNum, err := seqnoManager.GetCurrentSequenceNumber(fullName, "")
if err != nil {
return errors.Wrap(err, "failed to get current sequence number")
}
// If the most recent sequence number is greater than or equal to the requested sequence number,
// then the script has already been run and we should exit.
if mrSeqNum != 0 && seqNum <= mrSeqNum {
lg.Log("event", "exit", "message", "the script configuration has already been processed, will not run again")
return errors.Errorf("most recent sequence number %d is greater than or equal to requested sequence number %d", mrSeqNum, seqNum)
}

// save the sequence number
if err := seqnoManager.SetSequenceNumber(fullName, "", seqNum); err != nil {
return errors.Wrap(err, "failed to save sequence number")
}
return nil
}

func enable(lg log.Logger, h *handlerenv.HandlerEnvironment, seqNum int) (string, error) {
// parse the extension handler settings (not available prior to 'enable')
cfg, err := parseAndValidateSettings(lg, h.ConfigFolder)
Expand Down
102 changes: 102 additions & 0 deletions main/cmds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package main
import (
"testing"

"github.com/Azure/applicationhealth-extension-linux/internal/seqno"
"github.com/Azure/azure-extension-platform/pkg/extensionerrors"
"github.com/go-kit/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -29,3 +33,101 @@ func Test_commands_shouldReportStatus(t *testing.T) {
require.True(t, cmds["disable"].shouldReportStatus, "disable should report status")
require.True(t, cmds["update"].shouldReportStatus, "update should report status")
}

func Test_enablePre(t *testing.T) {
var (
logger = log.NewNopLogger()
seqNum = 5
)

t.Run("SequenceNumberAlreadyProcessed", func(t *testing.T) {
mockSeqNumManager := &seqno.SeqNumManager{
GetSequenceNumberFunc: seqno.GetSequenceNumberFunc,
SequenceNumberSetterFunc: seqno.SetSequenceNumber,
FindSeqNumFunc: seqno.FindSeqNum,
GetCurrentSequenceNumberFunc: func(lg log.Logger, name, version string) (int, error) {
return 5, nil
},
}
seqnoManager = mockSeqNumManager
err := enablePre(logger, seqNum)
assert.Error(t, err)
assert.EqualError(t, err, "most recent sequence number 5 is greater than or equal to requested sequence number 5")
})

t.Run("SaveSequenceNumberError_ShouldFail", func(t *testing.T) {
seqNum = 0
mockSeqNumManager := &seqno.SeqNumManager{
GetSequenceNumberFunc: seqno.GetSequenceNumberFunc,
SequenceNumberSetterFunc: seqno.SetSequenceNumber,
FindSeqNumFunc: seqno.FindSeqNum,
GetCurrentSequenceNumberFunc: func(lg log.Logger, name, version string) (int, error) {
return 1, nil
},
}
seqnoManager = mockSeqNumManager
err := enablePre(logger, seqNum)
assert.Error(t, err)
assert.EqualError(t, err, "most recent sequence number 1 is greater than or equal to requested sequence number 0")
})

t.Run("SequenceNumberisZero_ShouldPass", func(t *testing.T) {
seqNum = 0
mockSeqNumManager := &seqno.SeqNumManager{
GetSequenceNumberFunc: seqno.GetSequenceNumberFunc,
SequenceNumberSetterFunc: seqno.SetSequenceNumber,
FindSeqNumFunc: seqno.FindSeqNum,
GetCurrentSequenceNumberFunc: func(lg log.Logger, name, version string) (int, error) {
return 0, nil
},
}
seqnoManager = mockSeqNumManager
err := enablePre(logger, seqNum)
assert.NoError(t, err)
})
t.Run("MrSeqFileNotFound_ShouldPass", func(t *testing.T) {
seqNum = 0
mockGetSequenceNumberFunc := func(name, version string) (int, error) {
return 0, extensionerrors.ErrNoMrseqFile
}
mockSeqNumManager := &seqno.SeqNumManager{
GetSequenceNumberFunc: mockGetSequenceNumberFunc,
SequenceNumberSetterFunc: seqno.SetSequenceNumber,
FindSeqNumFunc: seqno.FindSeqNum,
GetCurrentSequenceNumberFunc: seqno.GetCurrentSequenceNumberFunc(mockGetSequenceNumberFunc),
}
seqnoManager = mockSeqNumManager
err := enablePre(logger, seqNum)
assert.NoError(t, err)
})
t.Run("GetSequenceNumberIsGreaterThanRequestedSequenceNumber_ShouldFail", func(t *testing.T) {
seqNum = 4
mockSeqNumManager := &seqno.SeqNumManager{
GetSequenceNumberFunc: seqno.GetSequenceNumberFunc,
SequenceNumberSetterFunc: seqno.SetSequenceNumber,
FindSeqNumFunc: seqno.FindSeqNum,
GetCurrentSequenceNumberFunc: func(lg log.Logger, name, version string) (int, error) {
return 8, nil
},
}
seqnoManager = mockSeqNumManager
err := enablePre(logger, seqNum)
assert.Error(t, err)
assert.EqualError(t, err, "most recent sequence number 8 is greater than or equal to requested sequence number 4")
})
t.Run("GetSequenceNumberIsEqualRequestedSequenceNumber_ShouldFail", func(t *testing.T) {
seqNum = 4
mockSeqNumManager := &seqno.SeqNumManager{
GetSequenceNumberFunc: seqno.GetSequenceNumberFunc,
SequenceNumberSetterFunc: seqno.SetSequenceNumber,
FindSeqNumFunc: seqno.FindSeqNum,
GetCurrentSequenceNumberFunc: func(lg log.Logger, name, version string) (int, error) {
return 4, nil
},
}
seqnoManager = mockSeqNumManager
err := enablePre(logger, seqNum)
assert.Error(t, err)
assert.EqualError(t, err, "most recent sequence number 4 is greater than or equal to requested sequence number 4")
})
}
6 changes: 5 additions & 1 deletion main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"syscall"

"github.com/Azure/applicationhealth-extension-linux/internal/handlerenv"
"github.com/Azure/applicationhealth-extension-linux/internal/seqno"
"github.com/Azure/applicationhealth-extension-linux/internal/telemetry"
"github.com/Azure/applicationhealth-extension-linux/pkg/logging"
"github.com/Azure/azure-extension-platform/pkg/extensionevents"
Expand All @@ -28,6 +29,8 @@ var (
eem *extensionevents.ExtensionEventManager

sendTelemetry telemetry.LogEventFunc

seqnoManager seqno.SequenceNumberManager = seqno.New()
)

func main() {
Expand Down Expand Up @@ -61,7 +64,8 @@ func main() {
logger.Log("message", "failed to parse handlerenv", "error", err)
os.Exit(cmd.failExitCode)
}
seqNum, err := FindSeqNum(hEnv.ConfigFolder)

seqNum, err := seqnoManager.FindSeqNum(hEnv.ConfigFolder)
if err != nil {
logger.Log("message", "failed to find sequence number", "error", err)
}
Expand Down
32 changes: 0 additions & 32 deletions main/seqnum.go

This file was deleted.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading