diff --git a/.gitignore b/.gitignore index bd7a535..dfa2a7e 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ _testmain.go *.exe *.test *.prof +mrseq diff --git a/go.mod b/go.mod index d677d5c..b020ca8 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,10 @@ require ( github.com/xeipuuv/gojsonschema v0.0.0-20160623135812-c539bca196be ) -require github.com/go-kit/log v0.2.0 +require ( + github.com/go-kit/log v0.2.0 + go.uber.org/mock v0.4.0 +) require ( github.com/cilium/ebpf v0.9.1 // indirect diff --git a/go.sum b/go.sum index 6ef24ca..39dbb30 100644 --- a/go.sum +++ b/go.sum @@ -51,6 +51,8 @@ github.com/xeipuuv/gojsonschema v0.0.0-20160623135812-c539bca196be h1:sRGd3e18iz github.com/xeipuuv/gojsonschema v0.0.0-20160623135812-c539bca196be/go.mod h1:5yf86TLmAcydyeJq5YvxkGPE2fm/u4myDekKRoLuqhs= go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/seqno/mock_seqno.go b/internal/seqno/mock_seqno.go new file mode 100644 index 0000000..ee2fd22 --- /dev/null +++ b/internal/seqno/mock_seqno.go @@ -0,0 +1,99 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: internal/seqno/seqno.go +// +// Generated by this command: +// +// mockgen -source=internal/seqno/seqno.go -destination=internal/seqno/mock_seqno.go -package=seqno +// + +// Package seqno is a generated GoMock package. +package seqno + +import ( + reflect "reflect" + + log "github.com/go-kit/log" + gomock "go.uber.org/mock/gomock" +) + +// MockSequenceNumberManager is a mock of SequenceNumberManager interface. +type MockSequenceNumberManager struct { + ctrl *gomock.Controller + recorder *MockSequenceNumberManagerMockRecorder +} + +// MockSequenceNumberManagerMockRecorder is the mock recorder for MockSequenceNumberManager. +type MockSequenceNumberManagerMockRecorder struct { + mock *MockSequenceNumberManager +} + +// NewMockSequenceNumberManager creates a new mock instance. +func NewMockSequenceNumberManager(ctrl *gomock.Controller) *MockSequenceNumberManager { + mock := &MockSequenceNumberManager{ctrl: ctrl} + mock.recorder = &MockSequenceNumberManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSequenceNumberManager) EXPECT() *MockSequenceNumberManagerMockRecorder { + return m.recorder +} + +// FindSeqNum mocks base method. +func (m *MockSequenceNumberManager) FindSeqNum(configFolder string) (uint, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FindSeqNum", configFolder) + ret0, _ := ret[0].(uint) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FindSeqNum indicates an expected call of FindSeqNum. +func (mr *MockSequenceNumberManagerMockRecorder) FindSeqNum(configFolder any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindSeqNum", reflect.TypeOf((*MockSequenceNumberManager)(nil).FindSeqNum), configFolder) +} + +// GetCurrentSequenceNumber mocks base method. +func (m *MockSequenceNumberManager) GetCurrentSequenceNumber(el log.Logger, name, version string) (uint, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCurrentSequenceNumber", el, name, version) + ret0, _ := ret[0].(uint) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCurrentSequenceNumber indicates an expected call of GetCurrentSequenceNumber. +func (mr *MockSequenceNumberManagerMockRecorder) GetCurrentSequenceNumber(el, name, version any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCurrentSequenceNumber", reflect.TypeOf((*MockSequenceNumberManager)(nil).GetCurrentSequenceNumber), el, name, version) +} + +// GetSequenceNumber mocks base method. +func (m *MockSequenceNumberManager) GetSequenceNumber(name, version string) (uint, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSequenceNumber", name, version) + ret0, _ := ret[0].(uint) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSequenceNumber indicates an expected call of GetSequenceNumber. +func (mr *MockSequenceNumberManagerMockRecorder) GetSequenceNumber(name, version any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSequenceNumber", reflect.TypeOf((*MockSequenceNumberManager)(nil).GetSequenceNumber), name, version) +} + +// SetSequenceNumber mocks base method. +func (m *MockSequenceNumberManager) SetSequenceNumber(name, version string, seqNo uint) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetSequenceNumber", name, version, seqNo) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetSequenceNumber indicates an expected call of SetSequenceNumber. +func (mr *MockSequenceNumberManagerMockRecorder) SetSequenceNumber(name, version, seqNo any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSequenceNumber", reflect.TypeOf((*MockSequenceNumberManager)(nil).SetSequenceNumber), name, version, seqNo) +} diff --git a/internal/seqno/seqno.go b/internal/seqno/seqno.go new file mode 100644 index 0000000..ceec0dd --- /dev/null +++ b/internal/seqno/seqno.go @@ -0,0 +1,64 @@ +package seqno + +import ( + "github.com/Azure/applicationhealth-extension-linux/pkg/logging" + "github.com/Azure/azure-extension-platform/pkg/extensionerrors" + "github.com/Azure/azure-extension-platform/pkg/seqno" + "github.com/go-kit/log" +) + +type SequenceNumberManager interface { + // GetCurrentSequenceNumber returns the current sequence number the extension is using + GetCurrentSequenceNumber(el log.Logger, name, version string) (uint, error) + + // GetSequenceNumber retrieves the sequence number from the MRSEQ file + GetSequenceNumber(name, version string) (uint, error) + + // SetSequenceNumber sets the sequence number to the MRSEQ file. + SetSequenceNumber(name, version string, seqNo uint) error + + // FindSeqNum returns the requested the sequence number from either the environment variable or + // the most recently used file under the config folder. + // Note that this is different than just choosing the highest number, which may be incorrect + FindSeqNum(configFolder string) (uint, error) +} + +type SeqNumManager struct { +} + +func (s *SeqNumManager) GetSequenceNumber(name string, version string) (uint, error) { + retriever := &seqno.ProdSequenceNumberRetriever{} + return retriever.GetSequenceNumber(name, version) +} + +// SetSequenceNumber sets the sequence number for the given extension name and version. +// It takes the extension name, extension version, and sequence number as parameters. +// The sequence number is an integer that represents the order in which the extension was installed. +// It returns an error if there was a problem setting the sequence number. +func (s *SeqNumManager) SetSequenceNumber(name, version string, seqNo uint) error { + return seqno.SetSequenceNumber(name, version, seqNo) +} + +// FindSeqNum returns the requested the sequence number from either the environment variable or +// the most recently used file under the config folder. +// Note that this is different than just choosing the highest number, which may be incorrect +func (s *SeqNumManager) FindSeqNum(configFolder string) (uint, error) { + return seqno.FindSeqNum(logging.NewNopLogger(), configFolder) +} + +// GetCurrentSequenceNumber returns the current sequence number the extension is using +func (s *SeqNumManager) GetCurrentSequenceNumber(lg log.Logger, name, version string) (sn uint, _ error) { + sequenceNumber, err := s.GetSequenceNumber(name, version) + if err == extensionerrors.ErrNotFound || err == extensionerrors.ErrNoMrseqFile { + // If we can't find the sequence number, then it's possible that the extension + // hasn't been installed yet. Go back to 0. + lg.Log("event", "Couldn't find current sequence number, likely first execution of the extension, returning sequence number 0") + return 0, nil + } + + return sequenceNumber, err +} + +func New() SequenceNumberManager { + return &SeqNumManager{} +} diff --git a/main/cmds.go b/main/cmds.go index 89e9cbb..70734df 100644 --- a/main/cmds.go +++ b/main/cmds.go @@ -12,8 +12,8 @@ import ( "github.com/pkg/errors" ) -type cmdFunc func(lg log.Logger, hEnv *handlerenv.HandlerEnvironment, seqNum int) (msg string, err error) -type preFunc func(lg log.Logger, seqNum int) error +type cmdFunc func(lg log.Logger, hEnv *handlerenv.HandlerEnvironment, seqNum uint) (msg string, err error) +type preFunc func(lg log.Logger, seqNum uint) error type cmd struct { f cmdFunc // associated function @@ -29,7 +29,7 @@ const ( var ( cmdInstall = cmd{install, "Install", false, nil, 52} - cmdEnable = cmd{enable, "Enable", true, nil, 3} + cmdEnable = cmd{enable, "Enable", true, enablePre, 3} cmdUninstall = cmd{uninstall, "Uninstall", false, nil, 3} cmds = map[string]cmd{ @@ -41,12 +41,12 @@ var ( } ) -func noop(lg log.Logger, h *handlerenv.HandlerEnvironment, seqNum int) (string, error) { +func noop(lg log.Logger, h *handlerenv.HandlerEnvironment, seqNum uint) (string, error) { lg.Log("event", "noop") return "", nil } -func install(lg log.Logger, h *handlerenv.HandlerEnvironment, seqNum int) (string, error) { +func install(lg log.Logger, h *handlerenv.HandlerEnvironment, seqNum uint) (string, error) { if err := os.MkdirAll(dataDir, 0755); err != nil { return "", errors.Wrap(err, "failed to create data dir") } @@ -56,7 +56,7 @@ func install(lg log.Logger, h *handlerenv.HandlerEnvironment, seqNum int) (strin return "", nil } -func uninstall(lg log.Logger, h *handlerenv.HandlerEnvironment, seqNum int) (string, error) { +func uninstall(lg log.Logger, h *handlerenv.HandlerEnvironment, seqNum uint) (string, error) { { // a new context scope with path lg = log.With(lg, "path", dataDir) sendTelemetry(lg, telemetry.EventLevelInfo, telemetry.AppHealthTask, "Removing data dir") @@ -77,7 +77,29 @@ var ( errTerminated = errors.New("Application health process terminated") ) -func enable(lg log.Logger, h *handlerenv.HandlerEnvironment, seqNum int) (string, error) { +func enablePre(lg log.Logger, seqNum uint) error { + // exit if this sequence number (a snapshot of the configuration) is already + // processed. if not, save this sequence number before proceeding. + + mrSeqNum, err := seqnoManager.GetCurrentSequenceNumber(lg, fullName, "") + if err != nil { + return errors.Wrap(err, "failed to get current sequence number") + } + // If the most recent sequence number is greater than or equal to the requested sequence number, + // then the script has already been run and we should exit. + if mrSeqNum != 0 && seqNum < mrSeqNum { + lg.Log("event", "exit", "message", "the script configuration has already been processed, will not run again") + return errors.Errorf("most recent sequence number %d is greater than the requested sequence number %d", mrSeqNum, seqNum) + } + + // save the sequence number + if err := seqnoManager.SetSequenceNumber(fullName, "", seqNum); err != nil { + return errors.Wrap(err, "failed to save sequence number") + } + return nil +} + +func enable(lg log.Logger, h *handlerenv.HandlerEnvironment, seqNum uint) (string, error) { // parse the extension handler settings (not available prior to 'enable') cfg, err := parseAndValidateSettings(lg, h.ConfigFolder) if err != nil { diff --git a/main/cmds_test.go b/main/cmds_test.go index 473d7ab..e637793 100644 --- a/main/cmds_test.go +++ b/main/cmds_test.go @@ -3,7 +3,11 @@ package main import ( "testing" + "github.com/Azure/applicationhealth-extension-linux/internal/seqno" + "github.com/go-kit/log" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" ) func Test_commandsExist(t *testing.T) { @@ -29,3 +33,58 @@ func Test_commands_shouldReportStatus(t *testing.T) { require.True(t, cmds["disable"].shouldReportStatus, "disable should report status") require.True(t, cmds["update"].shouldReportStatus, "update should report status") } + +func Test_enablePre(t *testing.T) { + var ( + logger = log.NewNopLogger() + seqNumToProcess uint + ctrl = gomock.NewController(t) + ) + + mockSeqNumManager := seqno.NewMockSequenceNumberManager(ctrl) + t.Run("SaveSequenceNumberError_ShouldFail", func(t *testing.T) { + // seqNumToProcess = 0, mrSeqNum = 1 + seqNumToProcess = 0 + mockSeqNumManager.EXPECT().GetCurrentSequenceNumber(gomock.Any(), gomock.Any(), gomock.Any()).Return(uint(1), nil) + seqnoManager = mockSeqNumManager + err := enablePre(logger, seqNumToProcess) + assert.Error(t, err) + assert.EqualError(t, err, "most recent sequence number 1 is greater than the requested sequence number 0") + }) + t.Run("GetSequenceNumberIsGreaterThanRequestedSequenceNumber_ShouldFail", func(t *testing.T) { + // seqNumToProcess = 4, mrSeqNum = 8 + seqNumToProcess = 4 + mockSeqNumManager.EXPECT().GetCurrentSequenceNumber(gomock.Any(), gomock.Any(), gomock.Any()).Return(uint(8), nil) + seqnoManager = mockSeqNumManager + err := enablePre(logger, seqNumToProcess) + assert.Error(t, err) + assert.EqualError(t, err, "most recent sequence number 8 is greater than the requested sequence number 4") + }) + t.Run("SequenceNumberisZero_Startup", func(t *testing.T) { + // seqNumToProcess = 0, mrSeqNum = 0 + seqNumToProcess = 0 + mockSeqNumManager.EXPECT().GetCurrentSequenceNumber(gomock.Any(), gomock.Any(), gomock.Any()).Return(uint(0), nil) + mockSeqNumManager.EXPECT().SetSequenceNumber(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + seqnoManager = mockSeqNumManager + err := enablePre(logger, seqNumToProcess) + assert.NoError(t, err) + }) + t.Run("SequenceNumberAlreadyProcessed", func(t *testing.T) { + // seqNumToProcess = 5, mrSeqNum = 5 + seqNumToProcess = 5 + seqnoManager = mockSeqNumManager + mockSeqNumManager.EXPECT().GetCurrentSequenceNumber(gomock.Any(), gomock.Any(), gomock.Any()).Return(uint(5), nil) + mockSeqNumManager.EXPECT().SetSequenceNumber(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + err := enablePre(logger, seqNumToProcess) + assert.NoError(t, err) + }) + t.Run("MostRecentSeqNumIsSmaller_ShouldPass", func(t *testing.T) { + // seqNumToProcess = 4, mrSeqNum = 2 + seqNumToProcess = 4 + mockSeqNumManager.EXPECT().GetCurrentSequenceNumber(gomock.Any(), gomock.Any(), gomock.Any()).Return(uint(2), nil) + mockSeqNumManager.EXPECT().SetSequenceNumber(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + seqnoManager = mockSeqNumManager + err := enablePre(logger, seqNumToProcess) + assert.NoError(t, err) + }) +} diff --git a/main/main.go b/main/main.go index e209d77..df4b743 100644 --- a/main/main.go +++ b/main/main.go @@ -9,6 +9,7 @@ import ( "syscall" "github.com/Azure/applicationhealth-extension-linux/internal/handlerenv" + "github.com/Azure/applicationhealth-extension-linux/internal/seqno" "github.com/Azure/applicationhealth-extension-linux/internal/telemetry" "github.com/Azure/applicationhealth-extension-linux/pkg/logging" "github.com/Azure/azure-extension-platform/pkg/extensionevents" @@ -28,6 +29,8 @@ var ( eem *extensionevents.ExtensionEventManager sendTelemetry telemetry.LogEventFunc + + seqnoManager seqno.SequenceNumberManager = seqno.New() ) func main() { @@ -61,7 +64,8 @@ func main() { logger.Log("message", "failed to parse handlerenv", "error", err) os.Exit(cmd.failExitCode) } - seqNum, err := FindSeqNum(hEnv.ConfigFolder) + + seqNum, err := seqnoManager.FindSeqNum(hEnv.ConfigFolder) if err != nil { logger.Log("message", "failed to find sequence number", "error", err) } diff --git a/main/reportstatus.go b/main/reportstatus.go index ca16693..46bd8a1 100644 --- a/main/reportstatus.go +++ b/main/reportstatus.go @@ -14,7 +14,7 @@ import ( // status. // // If an error occurs reporting the status, it will be logged and returned. -func reportStatus(lg log.Logger, hEnv *handlerenv.HandlerEnvironment, seqNum int, t StatusType, c cmd, msg string) error { +func reportStatus(lg log.Logger, hEnv *handlerenv.HandlerEnvironment, seqNum uint, t StatusType, c cmd, msg string) error { if !c.shouldReportStatus { lg.Log("status", "not reported for operation (by design)") return nil @@ -29,7 +29,7 @@ func reportStatus(lg log.Logger, hEnv *handlerenv.HandlerEnvironment, seqNum int return nil } -func reportStatusWithSubstatuses(lg log.Logger, hEnv *handlerenv.HandlerEnvironment, seqNum int, t StatusType, op string, msg string, substatuses []SubstatusItem) error { +func reportStatusWithSubstatuses(lg log.Logger, hEnv *handlerenv.HandlerEnvironment, seqNum uint, t StatusType, op string, msg string, substatuses []SubstatusItem) error { s := NewStatus(t, op, msg) for _, substatus := range substatuses { s.AddSubstatusItem(substatus) diff --git a/main/seqnum.go b/main/seqnum.go deleted file mode 100644 index 8b2d646..0000000 --- a/main/seqnum.go +++ /dev/null @@ -1,32 +0,0 @@ -package main - -import ( - "fmt" - "path/filepath" - "sort" - "strconv" - "strings" -) - -// FindSeqnum finds the file with the highest number under configFolder -// named like 0.settings, 1.settings so on. -func FindSeqNum(configFolder string) (int, error) { - g, err := filepath.Glob(configFolder + "/*.settings") - if err != nil { - return 0, err - } - seqs := make([]int, len(g)) - for _, v := range g { - f := filepath.Base(v) - i, err := strconv.Atoi(strings.Replace(f, ".settings", "", 1)) - if err != nil { - return 0, fmt.Errorf("Can't parse int from filename: %s", f) - } - seqs = append(seqs, i) - } - if len(seqs) == 0 { - return 0, fmt.Errorf("Can't find out seqnum from %s, not enough files.", configFolder) - } - sort.Sort(sort.Reverse(sort.IntSlice(seqs))) - return seqs[0], nil -} diff --git a/main/status.go b/main/status.go index c22faa9..c668b61 100644 --- a/main/status.go +++ b/main/status.go @@ -102,7 +102,7 @@ func (r StatusReport) marshal() ([]byte, error) { // Save persists the status message to the specified status folder using the // sequence number. The operation consists of writing to a temporary file in the // same folder and moving it to the final destination for atomicity. -func (r StatusReport) Save(statusFolder string, seqNum int) error { +func (r StatusReport) Save(statusFolder string, seqNum uint) error { fn := fmt.Sprintf("%d.status", seqNum) path := filepath.Join(statusFolder, fn) tmpFile, err := ioutil.TempFile(statusFolder, fn) diff --git a/vendor/github.com/Azure/azure-extension-platform/pkg/constants/constants.go b/vendor/github.com/Azure/azure-extension-platform/pkg/constants/constants.go new file mode 100644 index 0000000..f8e2a11 --- /dev/null +++ b/vendor/github.com/Azure/azure-extension-platform/pkg/constants/constants.go @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package constants + +const ( + FilePermissions_UserOnly_ReadWrite = 0600 + FilePermissions_UserOnly_ReadWriteExecute = 0700 +) diff --git a/vendor/github.com/Azure/azure-extension-platform/pkg/constants/constants_linux.go b/vendor/github.com/Azure/azure-extension-platform/pkg/constants/constants_linux.go new file mode 100644 index 0000000..0464d56 --- /dev/null +++ b/vendor/github.com/Azure/azure-extension-platform/pkg/constants/constants_linux.go @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package constants + +const ( + NewLineCharacter = "\n" +) diff --git a/vendor/github.com/Azure/azure-extension-platform/pkg/constants/constants_windows.go b/vendor/github.com/Azure/azure-extension-platform/pkg/constants/constants_windows.go new file mode 100644 index 0000000..bde4a9d --- /dev/null +++ b/vendor/github.com/Azure/azure-extension-platform/pkg/constants/constants_windows.go @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package constants + +const NewLineCharacter = "\r\n" diff --git a/vendor/github.com/Azure/azure-extension-platform/pkg/seqno/seqno.go b/vendor/github.com/Azure/azure-extension-platform/pkg/seqno/seqno.go new file mode 100644 index 0000000..7cdf57c --- /dev/null +++ b/vendor/github.com/Azure/azure-extension-platform/pkg/seqno/seqno.go @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package seqno + +import ( + "io/ioutil" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-extension-platform/pkg/extensionerrors" + "github.com/Azure/azure-extension-platform/pkg/logging" +) + +const configSequenceNumber = "ConfigSequenceNumber" + +type ISequenceNumberRetriever interface { + GetSequenceNumber(name, version string) (uint, error) +} + +type ProdSequenceNumberRetriever struct { +} + +func (*ProdSequenceNumberRetriever) GetSequenceNumber(name, version string) (uint, error) { + return getSequenceNumberInternal(name, version) +} + +// GetCurrentSequenceNumber returns the current sequence number the extension is using +func GetCurrentSequenceNumber(el logging.ILogger, retriever ISequenceNumberRetriever, name, version string) (sn uint, _ error) { + sequenceNumber, err := retriever.GetSequenceNumber(name, version) + if err == extensionerrors.ErrNotFound || err == extensionerrors.ErrNoMrseqFile { + // If we can't find the sequence number, then it's possible that the extension + // hasn't been installed yet. Go back to 0. + el.Info("Couldn't find current sequence number, likely first execution of the extension, returning sequence number 0") + return 0, nil + } + + return sequenceNumber, err +} + +func SetSequenceNumber(extName, extVersion string, seqNo uint) error { + return setSequenceNumberInternal(extName, extVersion, seqNo) +} + +// findSeqnum finds the most recently used file under the config folder +// Note that this is different than just choosing the highest number, which may be incorrect +func FindSeqNum(el logging.ILogger, configFolder string) (uint, error) { + // try getting the sequence number from the environment first + seqNoString := os.Getenv(configSequenceNumber) + if seqNoString == "" { + el.Info("could not read environment variable %s for getting sequence number", configSequenceNumber) + } else { + seqNo, err := strconv.ParseUint(seqNoString, 10, 64) + if err != nil { + el.Info("could not read sequence number string %s into unsigned integer", seqNoString) + } else { + el.Info("using sequence number %d from environment variable %s", seqNo, configSequenceNumber) + return uint(seqNo), nil + } + } + + g, err := filepath.Glob(configFolder + "/*.settings") + if err != nil { + return 0, err + } + + // Start by finding the file with the latest time + files, err := ioutil.ReadDir(configFolder) + if err != nil { + return 0, err + } + + var modTime time.Time + var names []string + for _, fi := range files { + if fi.Mode().IsRegular() && strings.HasSuffix(fi.Name(), ".settings") { + if !fi.ModTime().Before(modTime) { + if fi.ModTime().After(modTime) { + modTime = fi.ModTime() + names = names[:0] + } + + // Unlikely - but just in case the files have the same time + names = append(names, fi.Name()) + } + } + } + + if len(names) == 0 { + el.Error("Cannot find the seqNo from %s. Not enough files", configFolder) + return 0, extensionerrors.ErrNoSettingsFiles + } else if len(names) == 1 { + i, err := strconv.Atoi(strings.Replace(names[0], ".settings", "", 1)) + if err != nil { + el.Error("Can't parse int from filename: %s", names[0]) + return 0, extensionerrors.ErrInvalidSettingsFileName + } + + return uint(i), nil + } else { + // For some reason we have two or more files with the same time stamp. + // Revert to choosing the highest number. + seqs := make([]int, len(g)) + for _, v := range g { + f := filepath.Base(v) + i, err := strconv.Atoi(strings.Replace(f, ".settings", "", 1)) + if err != nil { + el.Error("Can't parse int from filename: %s", f) + return 0, extensionerrors.ErrInvalidSettingsFileName + } + seqs = append(seqs, i) + } + + sort.Sort(sort.Reverse(sort.IntSlice(seqs))) + return uint(seqs[0]), nil + } +} diff --git a/vendor/github.com/Azure/azure-extension-platform/pkg/seqno/seqno_linux.go b/vendor/github.com/Azure/azure-extension-platform/pkg/seqno/seqno_linux.go new file mode 100644 index 0000000..876bf76 --- /dev/null +++ b/vendor/github.com/Azure/azure-extension-platform/pkg/seqno/seqno_linux.go @@ -0,0 +1,62 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package seqno + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strconv" + + "github.com/Azure/azure-extension-platform/pkg/constants" + "github.com/Azure/azure-extension-platform/pkg/extensionerrors" + "github.com/Azure/azure-extension-platform/pkg/utils" +) + +var mostRecentSequenceFileName = "mrseq" + +// sequence number for the extension from the registry +func getSequenceNumberInternal(name, version string) (uint, error) { + mrseqPath, err := getMrseqFilePath() + if err != nil { + return 0, err + } + mrseqStr, err := ioutil.ReadFile(mrseqPath) + if err != nil { + if os.IsNotExist(err) { + return 0, extensionerrors.ErrNoMrseqFile + } + return 0, fmt.Errorf("failed to read mrseq file : %s", err) + } + + seqNum, err := strconv.Atoi(string(mrseqStr)) + if err != nil { + return 0, err + } + return uint(seqNum), nil + +} + +func setSequenceNumberInternal(extName, extVersion string, seqNo uint) error { + b := []byte(fmt.Sprintf("%v", seqNo)) + + mrseqPath, err := getMrseqFilePath() + if err != nil { + return err + } + err = ioutil.WriteFile(mrseqPath, b, constants.FilePermissions_UserOnly_ReadWrite) + if err != nil { + return fmt.Errorf("could not write sequence number file %s, error: %v", mostRecentSequenceFileName, err) + } + return nil +} + +func getMrseqFilePath() (string, error) { + // mrseq file path must always be present in the same directory as the extension executable + currentDir, err := utils.GetCurrentProcessWorkingDir() + if err != nil { + return "", err + } + return filepath.Join(currentDir, mostRecentSequenceFileName), nil +} diff --git a/vendor/github.com/Azure/azure-extension-platform/pkg/seqno/seqno_windows.go b/vendor/github.com/Azure/azure-extension-platform/pkg/seqno/seqno_windows.go new file mode 100644 index 0000000..470360b --- /dev/null +++ b/vendor/github.com/Azure/azure-extension-platform/pkg/seqno/seqno_windows.go @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +package seqno + +import ( + "fmt" + "github.com/Azure/azure-extension-platform/pkg/extensionerrors" + "golang.org/x/sys/windows/registry" + "path" + "strconv" +) + +const ( + sequenceNumberKeyName = "SequenceNumber" +) + +// getSequenceNumberInternal is the Windows specific logic for reading the current +// sequence number for the extension from the registry +func getSequenceNumberInternal(name, version string) (uint, error) { + extensionKeyName := getExtensionKeyName(name, version) + k, err := registry.OpenKey(registry.LOCAL_MACHINE, extensionKeyName, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + // This may happen if the extension isn't installed. Return a uniform error indicating this. + return 0, extensionerrors.ErrNotFound + } + + return 0, fmt.Errorf("VmExtension: Cannot open sequence registry key due to '%v'", err) + } + defer k.Close() + + buff := make([]byte, 32) + + _, valType, err := k.GetValue(sequenceNumberKeyName, buff) + if err != nil { + if err == registry.ErrNotExist { + return 0, extensionerrors.ErrNotFound + } + return 0, fmt.Errorf("VmExtension: Cannot read sequence registry key due to '%v'", err) + } + var value uint + switch valType { + case registry.SZ, registry.EXPAND_SZ: + stringVal, _, err := k.GetStringValue(sequenceNumberKeyName) + if err != nil { + return 0, err + } + val, err := strconv.ParseUint(stringVal, 10, 64) + if err != nil { + return 0, err + } + value = uint(val) + case registry.DWORD: + val, _, err := k.GetIntegerValue(sequenceNumberKeyName) + if err != nil { + return 0, err + } + value = uint(val) + default: + return 0, fmt.Errorf("value of registry key %s is of unexpected type", path.Join("HKEY_LOCAL_MACHINE", extensionKeyName, sequenceNumberKeyName)) + } + + return uint(value), nil +} + +func getExtensionKeyName(name string, version string) (keyName string) { + return fmt.Sprintf("Software\\Microsoft\\Windows Azure\\HandlerState\\%s_%s", name, version) +} + +// setSequenceNumberInternal writes the sequence number for the extension to the registry +func setSequenceNumberInternal(extName, extVersion string, seqNo uint) error { + extensionKeyName := getExtensionKeyName(extName, extVersion) + k, err := registry.OpenKey(registry.LOCAL_MACHINE, extensionKeyName, registry.WRITE) + if err != nil { + return fmt.Errorf("VmExtension: Cannot write sequence registry key due to '%v'", err) + } + defer k.Close() + + err = k.SetDWordValue(sequenceNumberKeyName, uint32(seqNo)) + return err +} diff --git a/vendor/go.uber.org/mock/AUTHORS b/vendor/go.uber.org/mock/AUTHORS new file mode 100644 index 0000000..660b8cc --- /dev/null +++ b/vendor/go.uber.org/mock/AUTHORS @@ -0,0 +1,12 @@ +# This is the official list of GoMock authors for copyright purposes. +# This file is distinct from the CONTRIBUTORS files. +# See the latter for an explanation. + +# Names should be added to this file as +# Name or Organization +# The email address is not required for organizations. + +# Please keep the list sorted. + +Alex Reece +Google Inc. diff --git a/vendor/go.uber.org/mock/LICENSE b/vendor/go.uber.org/mock/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/vendor/go.uber.org/mock/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/go.uber.org/mock/gomock/call.go b/vendor/go.uber.org/mock/gomock/call.go new file mode 100644 index 0000000..e1ea826 --- /dev/null +++ b/vendor/go.uber.org/mock/gomock/call.go @@ -0,0 +1,508 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomock + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +// Call represents an expected call to a mock. +type Call struct { + t TestHelper // for triggering test failures on invalid call setup + + receiver any // the receiver of the method call + method string // the name of the method + methodType reflect.Type // the type of the method + args []Matcher // the args + origin string // file and line number of call setup + + preReqs []*Call // prerequisite calls + + // Expectations + minCalls, maxCalls int + + numCalls int // actual number made + + // actions are called when this Call is called. Each action gets the args and + // can set the return values by returning a non-nil slice. Actions run in the + // order they are created. + actions []func([]any) []any +} + +// newCall creates a *Call. It requires the method type in order to support +// unexported methods. +func newCall(t TestHelper, receiver any, method string, methodType reflect.Type, args ...any) *Call { + t.Helper() + + // TODO: check arity, types. + mArgs := make([]Matcher, len(args)) + for i, arg := range args { + if m, ok := arg.(Matcher); ok { + mArgs[i] = m + } else if arg == nil { + // Handle nil specially so that passing a nil interface value + // will match the typed nils of concrete args. + mArgs[i] = Nil() + } else { + mArgs[i] = Eq(arg) + } + } + + // callerInfo's skip should be updated if the number of calls between the user's test + // and this line changes, i.e. this code is wrapped in another anonymous function. + // 0 is us, 1 is RecordCallWithMethodType(), 2 is the generated recorder, and 3 is the user's test. + origin := callerInfo(3) + actions := []func([]any) []any{func([]any) []any { + // Synthesize the zero value for each of the return args' types. + rets := make([]any, methodType.NumOut()) + for i := 0; i < methodType.NumOut(); i++ { + rets[i] = reflect.Zero(methodType.Out(i)).Interface() + } + return rets + }} + return &Call{t: t, receiver: receiver, method: method, methodType: methodType, + args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions} +} + +// AnyTimes allows the expectation to be called 0 or more times +func (c *Call) AnyTimes() *Call { + c.minCalls, c.maxCalls = 0, 1e8 // close enough to infinity + return c +} + +// MinTimes requires the call to occur at least n times. If AnyTimes or MaxTimes have not been called or if MaxTimes +// was previously called with 1, MinTimes also sets the maximum number of calls to infinity. +func (c *Call) MinTimes(n int) *Call { + c.minCalls = n + if c.maxCalls == 1 { + c.maxCalls = 1e8 + } + return c +} + +// MaxTimes limits the number of calls to n times. If AnyTimes or MinTimes have not been called or if MinTimes was +// previously called with 1, MaxTimes also sets the minimum number of calls to 0. +func (c *Call) MaxTimes(n int) *Call { + c.maxCalls = n + if c.minCalls == 1 { + c.minCalls = 0 + } + return c +} + +// DoAndReturn declares the action to run when the call is matched. +// The return values from this function are returned by the mocked function. +// It takes an any argument to support n-arity functions. +// The anonymous function must match the function signature mocked method. +func (c *Call) DoAndReturn(f any) *Call { + // TODO: Check arity and types here, rather than dying badly elsewhere. + v := reflect.ValueOf(f) + + c.addAction(func(args []any) []any { + c.t.Helper() + ft := v.Type() + if c.methodType.NumIn() != ft.NumIn() { + if ft.IsVariadic() { + c.t.Fatalf("wrong number of arguments in DoAndReturn func for %T.%v The function signature must match the mocked method, a variadic function cannot be used.", + c.receiver, c.method) + } else { + c.t.Fatalf("wrong number of arguments in DoAndReturn func for %T.%v: got %d, want %d [%s]", + c.receiver, c.method, ft.NumIn(), c.methodType.NumIn(), c.origin) + } + return nil + } + vArgs := make([]reflect.Value, len(args)) + for i := 0; i < len(args); i++ { + if args[i] != nil { + vArgs[i] = reflect.ValueOf(args[i]) + } else { + // Use the zero value for the arg. + vArgs[i] = reflect.Zero(ft.In(i)) + } + } + vRets := v.Call(vArgs) + rets := make([]any, len(vRets)) + for i, ret := range vRets { + rets[i] = ret.Interface() + } + return rets + }) + return c +} + +// Do declares the action to run when the call is matched. The function's +// return values are ignored to retain backward compatibility. To use the +// return values call DoAndReturn. +// It takes an any argument to support n-arity functions. +// The anonymous function must match the function signature mocked method. +func (c *Call) Do(f any) *Call { + // TODO: Check arity and types here, rather than dying badly elsewhere. + v := reflect.ValueOf(f) + + c.addAction(func(args []any) []any { + c.t.Helper() + ft := v.Type() + if c.methodType.NumIn() != ft.NumIn() { + if ft.IsVariadic() { + c.t.Fatalf("wrong number of arguments in Do func for %T.%v The function signature must match the mocked method, a variadic function cannot be used.", + c.receiver, c.method) + } else { + c.t.Fatalf("wrong number of arguments in Do func for %T.%v: got %d, want %d [%s]", + c.receiver, c.method, ft.NumIn(), c.methodType.NumIn(), c.origin) + } + return nil + } + vArgs := make([]reflect.Value, len(args)) + for i := 0; i < len(args); i++ { + if args[i] != nil { + vArgs[i] = reflect.ValueOf(args[i]) + } else { + // Use the zero value for the arg. + vArgs[i] = reflect.Zero(ft.In(i)) + } + } + v.Call(vArgs) + return nil + }) + return c +} + +// Return declares the values to be returned by the mocked function call. +func (c *Call) Return(rets ...any) *Call { + c.t.Helper() + + mt := c.methodType + if len(rets) != mt.NumOut() { + c.t.Fatalf("wrong number of arguments to Return for %T.%v: got %d, want %d [%s]", + c.receiver, c.method, len(rets), mt.NumOut(), c.origin) + } + for i, ret := range rets { + if got, want := reflect.TypeOf(ret), mt.Out(i); got == want { + // Identical types; nothing to do. + } else if got == nil { + // Nil needs special handling. + switch want.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + // ok + default: + c.t.Fatalf("argument %d to Return for %T.%v is nil, but %v is not nillable [%s]", + i, c.receiver, c.method, want, c.origin) + } + } else if got.AssignableTo(want) { + // Assignable type relation. Make the assignment now so that the generated code + // can return the values with a type assertion. + v := reflect.New(want).Elem() + v.Set(reflect.ValueOf(ret)) + rets[i] = v.Interface() + } else { + c.t.Fatalf("wrong type of argument %d to Return for %T.%v: %v is not assignable to %v [%s]", + i, c.receiver, c.method, got, want, c.origin) + } + } + + c.addAction(func([]any) []any { + return rets + }) + + return c +} + +// Times declares the exact number of times a function call is expected to be executed. +func (c *Call) Times(n int) *Call { + c.minCalls, c.maxCalls = n, n + return c +} + +// SetArg declares an action that will set the nth argument's value, +// indirected through a pointer. Or, in the case of a slice and map, SetArg +// will copy value's elements/key-value pairs into the nth argument. +func (c *Call) SetArg(n int, value any) *Call { + c.t.Helper() + + mt := c.methodType + // TODO: This will break on variadic methods. + // We will need to check those at invocation time. + if n < 0 || n >= mt.NumIn() { + c.t.Fatalf("SetArg(%d, ...) called for a method with %d args [%s]", + n, mt.NumIn(), c.origin) + } + // Permit setting argument through an interface. + // In the interface case, we don't (nay, can't) check the type here. + at := mt.In(n) + switch at.Kind() { + case reflect.Ptr: + dt := at.Elem() + if vt := reflect.TypeOf(value); !vt.AssignableTo(dt) { + c.t.Fatalf("SetArg(%d, ...) argument is a %v, not assignable to %v [%s]", + n, vt, dt, c.origin) + } + case reflect.Interface: + // nothing to do + case reflect.Slice: + // nothing to do + case reflect.Map: + // nothing to do + default: + c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice non-map type %v [%s]", + n, at, c.origin) + } + + c.addAction(func(args []any) []any { + v := reflect.ValueOf(value) + switch reflect.TypeOf(args[n]).Kind() { + case reflect.Slice: + setSlice(args[n], v) + case reflect.Map: + setMap(args[n], v) + default: + reflect.ValueOf(args[n]).Elem().Set(v) + } + return nil + }) + return c +} + +// isPreReq returns true if other is a direct or indirect prerequisite to c. +func (c *Call) isPreReq(other *Call) bool { + for _, preReq := range c.preReqs { + if other == preReq || preReq.isPreReq(other) { + return true + } + } + return false +} + +// After declares that the call may only match after preReq has been exhausted. +func (c *Call) After(preReq *Call) *Call { + c.t.Helper() + + if c == preReq { + c.t.Fatalf("A call isn't allowed to be its own prerequisite") + } + if preReq.isPreReq(c) { + c.t.Fatalf("Loop in call order: %v is a prerequisite to %v (possibly indirectly).", c, preReq) + } + + c.preReqs = append(c.preReqs, preReq) + return c +} + +// Returns true if the minimum number of calls have been made. +func (c *Call) satisfied() bool { + return c.numCalls >= c.minCalls +} + +// Returns true if the maximum number of calls have been made. +func (c *Call) exhausted() bool { + return c.numCalls >= c.maxCalls +} + +func (c *Call) String() string { + args := make([]string, len(c.args)) + for i, arg := range c.args { + args[i] = arg.String() + } + arguments := strings.Join(args, ", ") + return fmt.Sprintf("%T.%v(%s) %s", c.receiver, c.method, arguments, c.origin) +} + +// Tests if the given call matches the expected call. +// If yes, returns nil. If no, returns error with message explaining why it does not match. +func (c *Call) matches(args []any) error { + if !c.methodType.IsVariadic() { + if len(args) != len(c.args) { + return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d", + c.origin, len(args), len(c.args)) + } + + for i, m := range c.args { + if !m.Matches(args[i]) { + return fmt.Errorf( + "expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v", + c.origin, i, formatGottenArg(m, args[i]), m, + ) + } + } + } else { + if len(c.args) < c.methodType.NumIn()-1 { + return fmt.Errorf("expected call at %s has the wrong number of matchers. Got: %d, want: %d", + c.origin, len(c.args), c.methodType.NumIn()-1) + } + if len(c.args) != c.methodType.NumIn() && len(args) != len(c.args) { + return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d", + c.origin, len(args), len(c.args)) + } + if len(args) < len(c.args)-1 { + return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: greater than or equal to %d", + c.origin, len(args), len(c.args)-1) + } + + for i, m := range c.args { + if i < c.methodType.NumIn()-1 { + // Non-variadic args + if !m.Matches(args[i]) { + return fmt.Errorf("expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", + c.origin, strconv.Itoa(i), formatGottenArg(m, args[i]), m) + } + continue + } + // The last arg has a possibility of a variadic argument, so let it branch + + // sample: Foo(a int, b int, c ...int) + if i < len(c.args) && i < len(args) { + if m.Matches(args[i]) { + // Got Foo(a, b, c) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, someSliceMatcher) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC) + // Got Foo(a, b) want Foo(matcherA, matcherB) + // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD) + continue + } + } + + // The number of actual args don't match the number of matchers, + // or the last matcher is a slice and the last arg is not. + // If this function still matches it is because the last matcher + // matches all the remaining arguments or the lack of any. + // Convert the remaining arguments, if any, into a slice of the + // expected type. + vArgsType := c.methodType.In(c.methodType.NumIn() - 1) + vArgs := reflect.MakeSlice(vArgsType, 0, len(args)-i) + for _, arg := range args[i:] { + vArgs = reflect.Append(vArgs, reflect.ValueOf(arg)) + } + if m.Matches(vArgs.Interface()) { + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher) + // Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any()) + // Got Foo(a, b) want Foo(matcherA, matcherB, someEmptySliceMatcher) + break + } + // Wrong number of matchers or not match. Fail. + // Got Foo(a, b) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD, matcherE) + // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, matcherC, matcherD) + // Got Foo(a, b, c) want Foo(matcherA, matcherB) + + return fmt.Errorf("expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v", + c.origin, strconv.Itoa(i), formatGottenArg(m, args[i:]), c.args[i]) + } + } + + // Check that all prerequisite calls have been satisfied. + for _, preReqCall := range c.preReqs { + if !preReqCall.satisfied() { + return fmt.Errorf("expected call at %s doesn't have a prerequisite call satisfied:\n%v\nshould be called before:\n%v", + c.origin, preReqCall, c) + } + } + + // Check that the call is not exhausted. + if c.exhausted() { + return fmt.Errorf("expected call at %s has already been called the max number of times", c.origin) + } + + return nil +} + +// dropPrereqs tells the expected Call to not re-check prerequisite calls any +// longer, and to return its current set. +func (c *Call) dropPrereqs() (preReqs []*Call) { + preReqs = c.preReqs + c.preReqs = nil + return +} + +func (c *Call) call() []func([]any) []any { + c.numCalls++ + return c.actions +} + +// InOrder declares that the given calls should occur in order. +// It panics if the type of any of the arguments isn't *Call or a generated +// mock with an embedded *Call. +func InOrder(args ...any) { + calls := make([]*Call, 0, len(args)) + for i := 0; i < len(args); i++ { + if call := getCall(args[i]); call != nil { + calls = append(calls, call) + continue + } + panic(fmt.Sprintf( + "invalid argument at position %d of type %T, InOrder expects *gomock.Call or generated mock types with an embedded *gomock.Call", + i, + args[i], + )) + } + for i := 1; i < len(calls); i++ { + calls[i].After(calls[i-1]) + } +} + +// getCall checks if the parameter is a *Call or a generated struct +// that wraps a *Call and returns the *Call pointer - if neither, it returns nil. +func getCall(arg any) *Call { + if call, ok := arg.(*Call); ok { + return call + } + t := reflect.ValueOf(arg) + if t.Kind() != reflect.Ptr && t.Kind() != reflect.Interface { + return nil + } + t = t.Elem() + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.CanInterface() { + continue + } + if call, ok := f.Interface().(*Call); ok { + return call + } + } + return nil +} + +func setSlice(arg any, v reflect.Value) { + va := reflect.ValueOf(arg) + for i := 0; i < v.Len(); i++ { + va.Index(i).Set(v.Index(i)) + } +} + +func setMap(arg any, v reflect.Value) { + va := reflect.ValueOf(arg) + for _, e := range va.MapKeys() { + va.SetMapIndex(e, reflect.Value{}) + } + for _, e := range v.MapKeys() { + va.SetMapIndex(e, v.MapIndex(e)) + } +} + +func (c *Call) addAction(action func([]any) []any) { + c.actions = append(c.actions, action) +} + +func formatGottenArg(m Matcher, arg any) string { + got := fmt.Sprintf("%v (%T)", arg, arg) + if gs, ok := m.(GotFormatter); ok { + got = gs.Got(arg) + } + return got +} diff --git a/vendor/go.uber.org/mock/gomock/callset.go b/vendor/go.uber.org/mock/gomock/callset.go new file mode 100644 index 0000000..f5cc592 --- /dev/null +++ b/vendor/go.uber.org/mock/gomock/callset.go @@ -0,0 +1,164 @@ +// Copyright 2011 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomock + +import ( + "bytes" + "errors" + "fmt" + "sync" +) + +// callSet represents a set of expected calls, indexed by receiver and method +// name. +type callSet struct { + // Calls that are still expected. + expected map[callSetKey][]*Call + expectedMu *sync.Mutex + // Calls that have been exhausted. + exhausted map[callSetKey][]*Call + // when set to true, existing call expectations are overridden when new call expectations are made + allowOverride bool +} + +// callSetKey is the key in the maps in callSet +type callSetKey struct { + receiver any + fname string +} + +func newCallSet() *callSet { + return &callSet{ + expected: make(map[callSetKey][]*Call), + expectedMu: &sync.Mutex{}, + exhausted: make(map[callSetKey][]*Call), + } +} + +func newOverridableCallSet() *callSet { + return &callSet{ + expected: make(map[callSetKey][]*Call), + expectedMu: &sync.Mutex{}, + exhausted: make(map[callSetKey][]*Call), + allowOverride: true, + } +} + +// Add adds a new expected call. +func (cs callSet) Add(call *Call) { + key := callSetKey{call.receiver, call.method} + + cs.expectedMu.Lock() + defer cs.expectedMu.Unlock() + + m := cs.expected + if call.exhausted() { + m = cs.exhausted + } + if cs.allowOverride { + m[key] = make([]*Call, 0) + } + + m[key] = append(m[key], call) +} + +// Remove removes an expected call. +func (cs callSet) Remove(call *Call) { + key := callSetKey{call.receiver, call.method} + + cs.expectedMu.Lock() + defer cs.expectedMu.Unlock() + + calls := cs.expected[key] + for i, c := range calls { + if c == call { + // maintain order for remaining calls + cs.expected[key] = append(calls[:i], calls[i+1:]...) + cs.exhausted[key] = append(cs.exhausted[key], call) + break + } + } +} + +// FindMatch searches for a matching call. Returns error with explanation message if no call matched. +func (cs callSet) FindMatch(receiver any, method string, args []any) (*Call, error) { + key := callSetKey{receiver, method} + + cs.expectedMu.Lock() + defer cs.expectedMu.Unlock() + + // Search through the expected calls. + expected := cs.expected[key] + var callsErrors bytes.Buffer + for _, call := range expected { + err := call.matches(args) + if err != nil { + _, _ = fmt.Fprintf(&callsErrors, "\n%v", err) + } else { + return call, nil + } + } + + // If we haven't found a match then search through the exhausted calls so we + // get useful error messages. + exhausted := cs.exhausted[key] + for _, call := range exhausted { + if err := call.matches(args); err != nil { + _, _ = fmt.Fprintf(&callsErrors, "\n%v", err) + continue + } + _, _ = fmt.Fprintf( + &callsErrors, "all expected calls for method %q have been exhausted", method, + ) + } + + if len(expected)+len(exhausted) == 0 { + _, _ = fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method) + } + + return nil, errors.New(callsErrors.String()) +} + +// Failures returns the calls that are not satisfied. +func (cs callSet) Failures() []*Call { + cs.expectedMu.Lock() + defer cs.expectedMu.Unlock() + + failures := make([]*Call, 0, len(cs.expected)) + for _, calls := range cs.expected { + for _, call := range calls { + if !call.satisfied() { + failures = append(failures, call) + } + } + } + return failures +} + +// Satisfied returns true in case all expected calls in this callSet are satisfied. +func (cs callSet) Satisfied() bool { + cs.expectedMu.Lock() + defer cs.expectedMu.Unlock() + + for _, calls := range cs.expected { + for _, call := range calls { + if !call.satisfied() { + return false + } + } + } + + return true +} diff --git a/vendor/go.uber.org/mock/gomock/controller.go b/vendor/go.uber.org/mock/gomock/controller.go new file mode 100644 index 0000000..9d17a2f --- /dev/null +++ b/vendor/go.uber.org/mock/gomock/controller.go @@ -0,0 +1,318 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomock + +import ( + "context" + "fmt" + "reflect" + "runtime" + "sync" +) + +// A TestReporter is something that can be used to report test failures. It +// is satisfied by the standard library's *testing.T. +type TestReporter interface { + Errorf(format string, args ...any) + Fatalf(format string, args ...any) +} + +// TestHelper is a TestReporter that has the Helper method. It is satisfied +// by the standard library's *testing.T. +type TestHelper interface { + TestReporter + Helper() +} + +// cleanuper is used to check if TestHelper also has the `Cleanup` method. A +// common pattern is to pass in a `*testing.T` to +// `NewController(t TestReporter)`. In Go 1.14+, `*testing.T` has a cleanup +// method. This can be utilized to call `Finish()` so the caller of this library +// does not have to. +type cleanuper interface { + Cleanup(func()) +} + +// A Controller represents the top-level control of a mock ecosystem. It +// defines the scope and lifetime of mock objects, as well as their +// expectations. It is safe to call Controller's methods from multiple +// goroutines. Each test should create a new Controller and invoke Finish via +// defer. +// +// func TestFoo(t *testing.T) { +// ctrl := gomock.NewController(t) +// // .. +// } +// +// func TestBar(t *testing.T) { +// t.Run("Sub-Test-1", st) { +// ctrl := gomock.NewController(st) +// // .. +// }) +// t.Run("Sub-Test-2", st) { +// ctrl := gomock.NewController(st) +// // .. +// }) +// }) +type Controller struct { + // T should only be called within a generated mock. It is not intended to + // be used in user code and may be changed in future versions. T is the + // TestReporter passed in when creating the Controller via NewController. + // If the TestReporter does not implement a TestHelper it will be wrapped + // with a nopTestHelper. + T TestHelper + mu sync.Mutex + expectedCalls *callSet + finished bool +} + +// NewController returns a new Controller. It is the preferred way to create a Controller. +// +// Passing [*testing.T] registers cleanup function to automatically call [Controller.Finish] +// when the test and all its subtests complete. +func NewController(t TestReporter, opts ...ControllerOption) *Controller { + h, ok := t.(TestHelper) + if !ok { + h = &nopTestHelper{t} + } + ctrl := &Controller{ + T: h, + expectedCalls: newCallSet(), + } + for _, opt := range opts { + opt.apply(ctrl) + } + if c, ok := isCleanuper(ctrl.T); ok { + c.Cleanup(func() { + ctrl.T.Helper() + ctrl.finish(true, nil) + }) + } + + return ctrl +} + +// ControllerOption configures how a Controller should behave. +type ControllerOption interface { + apply(*Controller) +} + +type overridableExpectationsOption struct{} + +// WithOverridableExpectations allows for overridable call expectations +// i.e., subsequent call expectations override existing call expectations +func WithOverridableExpectations() overridableExpectationsOption { + return overridableExpectationsOption{} +} + +func (o overridableExpectationsOption) apply(ctrl *Controller) { + ctrl.expectedCalls = newOverridableCallSet() +} + +type cancelReporter struct { + t TestHelper + cancel func() +} + +func (r *cancelReporter) Errorf(format string, args ...any) { + r.t.Errorf(format, args...) +} +func (r *cancelReporter) Fatalf(format string, args ...any) { + defer r.cancel() + r.t.Fatalf(format, args...) +} + +func (r *cancelReporter) Helper() { + r.t.Helper() +} + +// WithContext returns a new Controller and a Context, which is cancelled on any +// fatal failure. +func WithContext(ctx context.Context, t TestReporter) (*Controller, context.Context) { + h, ok := t.(TestHelper) + if !ok { + h = &nopTestHelper{t: t} + } + + ctx, cancel := context.WithCancel(ctx) + return NewController(&cancelReporter{t: h, cancel: cancel}), ctx +} + +type nopTestHelper struct { + t TestReporter +} + +func (h *nopTestHelper) Errorf(format string, args ...any) { + h.t.Errorf(format, args...) +} +func (h *nopTestHelper) Fatalf(format string, args ...any) { + h.t.Fatalf(format, args...) +} + +func (h nopTestHelper) Helper() {} + +// RecordCall is called by a mock. It should not be called by user code. +func (ctrl *Controller) RecordCall(receiver any, method string, args ...any) *Call { + ctrl.T.Helper() + + recv := reflect.ValueOf(receiver) + for i := 0; i < recv.Type().NumMethod(); i++ { + if recv.Type().Method(i).Name == method { + return ctrl.RecordCallWithMethodType(receiver, method, recv.Method(i).Type(), args...) + } + } + ctrl.T.Fatalf("gomock: failed finding method %s on %T", method, receiver) + panic("unreachable") +} + +// RecordCallWithMethodType is called by a mock. It should not be called by user code. +func (ctrl *Controller) RecordCallWithMethodType(receiver any, method string, methodType reflect.Type, args ...any) *Call { + ctrl.T.Helper() + + call := newCall(ctrl.T, receiver, method, methodType, args...) + + ctrl.mu.Lock() + defer ctrl.mu.Unlock() + ctrl.expectedCalls.Add(call) + + return call +} + +// Call is called by a mock. It should not be called by user code. +func (ctrl *Controller) Call(receiver any, method string, args ...any) []any { + ctrl.T.Helper() + + // Nest this code so we can use defer to make sure the lock is released. + actions := func() []func([]any) []any { + ctrl.T.Helper() + ctrl.mu.Lock() + defer ctrl.mu.Unlock() + + expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args) + if err != nil { + // callerInfo's skip should be updated if the number of calls between the user's test + // and this line changes, i.e. this code is wrapped in another anonymous function. + // 0 is us, 1 is controller.Call(), 2 is the generated mock, and 3 is the user's test. + origin := callerInfo(3) + ctrl.T.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err) + } + + // Two things happen here: + // * the matching call no longer needs to check prerequite calls, + // * and the prerequite calls are no longer expected, so remove them. + preReqCalls := expected.dropPrereqs() + for _, preReqCall := range preReqCalls { + ctrl.expectedCalls.Remove(preReqCall) + } + + actions := expected.call() + if expected.exhausted() { + ctrl.expectedCalls.Remove(expected) + } + return actions + }() + + var rets []any + for _, action := range actions { + if r := action(args); r != nil { + rets = r + } + } + + return rets +} + +// Finish checks to see if all the methods that were expected to be called were called. +// It is not idempotent and therefore can only be invoked once. +func (ctrl *Controller) Finish() { + // If we're currently panicking, probably because this is a deferred call. + // This must be recovered in the deferred function. + err := recover() + ctrl.finish(false, err) +} + +// Satisfied returns whether all expected calls bound to this Controller have been satisfied. +// Calling Finish is then guaranteed to not fail due to missing calls. +func (ctrl *Controller) Satisfied() bool { + ctrl.mu.Lock() + defer ctrl.mu.Unlock() + return ctrl.expectedCalls.Satisfied() +} + +func (ctrl *Controller) finish(cleanup bool, panicErr any) { + ctrl.T.Helper() + + ctrl.mu.Lock() + defer ctrl.mu.Unlock() + + if ctrl.finished { + if _, ok := isCleanuper(ctrl.T); !ok { + ctrl.T.Fatalf("Controller.Finish was called more than once. It has to be called exactly once.") + } + return + } + ctrl.finished = true + + // Short-circuit, pass through the panic. + if panicErr != nil { + panic(panicErr) + } + + // Check that all remaining expected calls are satisfied. + failures := ctrl.expectedCalls.Failures() + for _, call := range failures { + ctrl.T.Errorf("missing call(s) to %v", call) + } + if len(failures) != 0 { + if !cleanup { + ctrl.T.Fatalf("aborting test due to missing call(s)") + return + } + ctrl.T.Errorf("aborting test due to missing call(s)") + } +} + +// callerInfo returns the file:line of the call site. skip is the number +// of stack frames to skip when reporting. 0 is callerInfo's call site. +func callerInfo(skip int) string { + if _, file, line, ok := runtime.Caller(skip + 1); ok { + return fmt.Sprintf("%s:%d", file, line) + } + return "unknown file" +} + +// isCleanuper checks it if t's base TestReporter has a Cleanup method. +func isCleanuper(t TestReporter) (cleanuper, bool) { + tr := unwrapTestReporter(t) + c, ok := tr.(cleanuper) + return c, ok +} + +// unwrapTestReporter unwraps TestReporter to the base implementation. +func unwrapTestReporter(t TestReporter) TestReporter { + tr := t + switch nt := t.(type) { + case *cancelReporter: + tr = nt.t + if h, check := tr.(*nopTestHelper); check { + tr = h.t + } + case *nopTestHelper: + tr = nt.t + default: + // not wrapped + } + return tr +} diff --git a/vendor/go.uber.org/mock/gomock/doc.go b/vendor/go.uber.org/mock/gomock/doc.go new file mode 100644 index 0000000..696dda3 --- /dev/null +++ b/vendor/go.uber.org/mock/gomock/doc.go @@ -0,0 +1,60 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package gomock is a mock framework for Go. +// +// Standard usage: +// +// (1) Define an interface that you wish to mock. +// type MyInterface interface { +// SomeMethod(x int64, y string) +// } +// (2) Use mockgen to generate a mock from the interface. +// (3) Use the mock in a test: +// func TestMyThing(t *testing.T) { +// mockCtrl := gomock.NewController(t) +// mockObj := something.NewMockMyInterface(mockCtrl) +// mockObj.EXPECT().SomeMethod(4, "blah") +// // pass mockObj to a real object and play with it. +// } +// +// By default, expected calls are not enforced to run in any particular order. +// Call order dependency can be enforced by use of InOrder and/or Call.After. +// Call.After can create more varied call order dependencies, but InOrder is +// often more convenient. +// +// The following examples create equivalent call order dependencies. +// +// Example of using Call.After to chain expected call order: +// +// firstCall := mockObj.EXPECT().SomeMethod(1, "first") +// secondCall := mockObj.EXPECT().SomeMethod(2, "second").After(firstCall) +// mockObj.EXPECT().SomeMethod(3, "third").After(secondCall) +// +// Example of using InOrder to declare expected call order: +// +// gomock.InOrder( +// mockObj.EXPECT().SomeMethod(1, "first"), +// mockObj.EXPECT().SomeMethod(2, "second"), +// mockObj.EXPECT().SomeMethod(3, "third"), +// ) +// +// The standard TestReporter most users will pass to `NewController` is a +// `*testing.T` from the context of the test. Note that this will use the +// standard `t.Error` and `t.Fatal` methods to report what happened in the test. +// In some cases this can leave your testing package in a weird state if global +// state is used since `t.Fatal` is like calling panic in the middle of a +// function. In these cases it is recommended that you pass in your own +// `TestReporter`. +package gomock diff --git a/vendor/go.uber.org/mock/gomock/matchers.go b/vendor/go.uber.org/mock/gomock/matchers.go new file mode 100644 index 0000000..c172550 --- /dev/null +++ b/vendor/go.uber.org/mock/gomock/matchers.go @@ -0,0 +1,443 @@ +// Copyright 2010 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gomock + +import ( + "fmt" + "reflect" + "regexp" + "strings" +) + +// A Matcher is a representation of a class of values. +// It is used to represent the valid or expected arguments to a mocked method. +type Matcher interface { + // Matches returns whether x is a match. + Matches(x any) bool + + // String describes what the matcher matches. + String() string +} + +// WantFormatter modifies the given Matcher's String() method to the given +// Stringer. This allows for control on how the "Want" is formatted when +// printing . +func WantFormatter(s fmt.Stringer, m Matcher) Matcher { + type matcher interface { + Matches(x any) bool + } + + return struct { + matcher + fmt.Stringer + }{ + matcher: m, + Stringer: s, + } +} + +// StringerFunc type is an adapter to allow the use of ordinary functions as +// a Stringer. If f is a function with the appropriate signature, +// StringerFunc(f) is a Stringer that calls f. +type StringerFunc func() string + +// String implements fmt.Stringer. +func (f StringerFunc) String() string { + return f() +} + +// GotFormatter is used to better print failure messages. If a matcher +// implements GotFormatter, it will use the result from Got when printing +// the failure message. +type GotFormatter interface { + // Got is invoked with the received value. The result is used when + // printing the failure message. + Got(got any) string +} + +// GotFormatterFunc type is an adapter to allow the use of ordinary +// functions as a GotFormatter. If f is a function with the appropriate +// signature, GotFormatterFunc(f) is a GotFormatter that calls f. +type GotFormatterFunc func(got any) string + +// Got implements GotFormatter. +func (f GotFormatterFunc) Got(got any) string { + return f(got) +} + +// GotFormatterAdapter attaches a GotFormatter to a Matcher. +func GotFormatterAdapter(s GotFormatter, m Matcher) Matcher { + return struct { + GotFormatter + Matcher + }{ + GotFormatter: s, + Matcher: m, + } +} + +type anyMatcher struct{} + +func (anyMatcher) Matches(any) bool { + return true +} + +func (anyMatcher) String() string { + return "is anything" +} + +type condMatcher struct { + fn func(x any) bool +} + +func (c condMatcher) Matches(x any) bool { + return c.fn(x) +} + +func (condMatcher) String() string { + return "adheres to a custom condition" +} + +type eqMatcher struct { + x any +} + +func (e eqMatcher) Matches(x any) bool { + // In case, some value is nil + if e.x == nil || x == nil { + return reflect.DeepEqual(e.x, x) + } + + // Check if types assignable and convert them to common type + x1Val := reflect.ValueOf(e.x) + x2Val := reflect.ValueOf(x) + + if x1Val.Type().AssignableTo(x2Val.Type()) { + x1ValConverted := x1Val.Convert(x2Val.Type()) + return reflect.DeepEqual(x1ValConverted.Interface(), x2Val.Interface()) + } + + return false +} + +func (e eqMatcher) String() string { + return fmt.Sprintf("is equal to %v (%T)", e.x, e.x) +} + +type nilMatcher struct{} + +func (nilMatcher) Matches(x any) bool { + if x == nil { + return true + } + + v := reflect.ValueOf(x) + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, + reflect.Ptr, reflect.Slice: + return v.IsNil() + } + + return false +} + +func (nilMatcher) String() string { + return "is nil" +} + +type notMatcher struct { + m Matcher +} + +func (n notMatcher) Matches(x any) bool { + return !n.m.Matches(x) +} + +func (n notMatcher) String() string { + return "not(" + n.m.String() + ")" +} + +type regexMatcher struct { + regex *regexp.Regexp +} + +func (m regexMatcher) Matches(x any) bool { + switch t := x.(type) { + case string: + return m.regex.MatchString(t) + case []byte: + return m.regex.Match(t) + default: + return false + } +} + +func (m regexMatcher) String() string { + return "matches regex " + m.regex.String() +} + +type assignableToTypeOfMatcher struct { + targetType reflect.Type +} + +func (m assignableToTypeOfMatcher) Matches(x any) bool { + return reflect.TypeOf(x).AssignableTo(m.targetType) +} + +func (m assignableToTypeOfMatcher) String() string { + return "is assignable to " + m.targetType.Name() +} + +type anyOfMatcher struct { + matchers []Matcher +} + +func (am anyOfMatcher) Matches(x any) bool { + for _, m := range am.matchers { + if m.Matches(x) { + return true + } + } + return false +} + +func (am anyOfMatcher) String() string { + ss := make([]string, 0, len(am.matchers)) + for _, matcher := range am.matchers { + ss = append(ss, matcher.String()) + } + return strings.Join(ss, " | ") +} + +type allMatcher struct { + matchers []Matcher +} + +func (am allMatcher) Matches(x any) bool { + for _, m := range am.matchers { + if !m.Matches(x) { + return false + } + } + return true +} + +func (am allMatcher) String() string { + ss := make([]string, 0, len(am.matchers)) + for _, matcher := range am.matchers { + ss = append(ss, matcher.String()) + } + return strings.Join(ss, "; ") +} + +type lenMatcher struct { + i int +} + +func (m lenMatcher) Matches(x any) bool { + v := reflect.ValueOf(x) + switch v.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == m.i + default: + return false + } +} + +func (m lenMatcher) String() string { + return fmt.Sprintf("has length %d", m.i) +} + +type inAnyOrderMatcher struct { + x any +} + +func (m inAnyOrderMatcher) Matches(x any) bool { + given, ok := m.prepareValue(x) + if !ok { + return false + } + wanted, ok := m.prepareValue(m.x) + if !ok { + return false + } + + if given.Len() != wanted.Len() { + return false + } + + usedFromGiven := make([]bool, given.Len()) + foundFromWanted := make([]bool, wanted.Len()) + for i := 0; i < wanted.Len(); i++ { + wantedMatcher := Eq(wanted.Index(i).Interface()) + for j := 0; j < given.Len(); j++ { + if usedFromGiven[j] { + continue + } + if wantedMatcher.Matches(given.Index(j).Interface()) { + foundFromWanted[i] = true + usedFromGiven[j] = true + break + } + } + } + + missingFromWanted := 0 + for _, found := range foundFromWanted { + if !found { + missingFromWanted++ + } + } + extraInGiven := 0 + for _, used := range usedFromGiven { + if !used { + extraInGiven++ + } + } + + return extraInGiven == 0 && missingFromWanted == 0 +} + +func (m inAnyOrderMatcher) prepareValue(x any) (reflect.Value, bool) { + xValue := reflect.ValueOf(x) + switch xValue.Kind() { + case reflect.Slice, reflect.Array: + return xValue, true + default: + return reflect.Value{}, false + } +} + +func (m inAnyOrderMatcher) String() string { + return fmt.Sprintf("has the same elements as %v", m.x) +} + +// Constructors + +// All returns a composite Matcher that returns true if and only all of the +// matchers return true. +func All(ms ...Matcher) Matcher { return allMatcher{ms} } + +// Any returns a matcher that always matches. +func Any() Matcher { return anyMatcher{} } + +// Cond returns a matcher that matches when the given function returns true +// after passing it the parameter to the mock function. +// This is particularly useful in case you want to match over a field of a custom struct, or dynamic logic. +// +// Example usage: +// +// Cond(func(x any){return x.(int) == 1}).Matches(1) // returns true +// Cond(func(x any){return x.(int) == 2}).Matches(1) // returns false +func Cond(fn func(x any) bool) Matcher { return condMatcher{fn} } + +// AnyOf returns a composite Matcher that returns true if at least one of the +// matchers returns true. +// +// Example usage: +// +// AnyOf(1, 2, 3).Matches(2) // returns true +// AnyOf(1, 2, 3).Matches(10) // returns false +// AnyOf(Nil(), Len(2)).Matches(nil) // returns true +// AnyOf(Nil(), Len(2)).Matches("hi") // returns true +// AnyOf(Nil(), Len(2)).Matches("hello") // returns false +func AnyOf(xs ...any) Matcher { + ms := make([]Matcher, 0, len(xs)) + for _, x := range xs { + if m, ok := x.(Matcher); ok { + ms = append(ms, m) + } else { + ms = append(ms, Eq(x)) + } + } + return anyOfMatcher{ms} +} + +// Eq returns a matcher that matches on equality. +// +// Example usage: +// +// Eq(5).Matches(5) // returns true +// Eq(5).Matches(4) // returns false +func Eq(x any) Matcher { return eqMatcher{x} } + +// Len returns a matcher that matches on length. This matcher returns false if +// is compared to a type that is not an array, chan, map, slice, or string. +func Len(i int) Matcher { + return lenMatcher{i} +} + +// Nil returns a matcher that matches if the received value is nil. +// +// Example usage: +// +// var x *bytes.Buffer +// Nil().Matches(x) // returns true +// x = &bytes.Buffer{} +// Nil().Matches(x) // returns false +func Nil() Matcher { return nilMatcher{} } + +// Not reverses the results of its given child matcher. +// +// Example usage: +// +// Not(Eq(5)).Matches(4) // returns true +// Not(Eq(5)).Matches(5) // returns false +func Not(x any) Matcher { + if m, ok := x.(Matcher); ok { + return notMatcher{m} + } + return notMatcher{Eq(x)} +} + +// Regex checks whether parameter matches the associated regex. +// +// Example usage: +// +// Regex("[0-9]{2}:[0-9]{2}").Matches("23:02") // returns true +// Regex("[0-9]{2}:[0-9]{2}").Matches([]byte{'2', '3', ':', '0', '2'}) // returns true +// Regex("[0-9]{2}:[0-9]{2}").Matches("hello world") // returns false +// Regex("[0-9]{2}").Matches(21) // returns false as it's not a valid type +func Regex(regexStr string) Matcher { + return regexMatcher{regex: regexp.MustCompile(regexStr)} +} + +// AssignableToTypeOf is a Matcher that matches if the parameter to the mock +// function is assignable to the type of the parameter to this function. +// +// Example usage: +// +// var s fmt.Stringer = &bytes.Buffer{} +// AssignableToTypeOf(s).Matches(time.Second) // returns true +// AssignableToTypeOf(s).Matches(99) // returns false +// +// var ctx = reflect.TypeOf((*context.Context)(nil)).Elem() +// AssignableToTypeOf(ctx).Matches(context.Background()) // returns true +func AssignableToTypeOf(x any) Matcher { + if xt, ok := x.(reflect.Type); ok { + return assignableToTypeOfMatcher{xt} + } + return assignableToTypeOfMatcher{reflect.TypeOf(x)} +} + +// InAnyOrder is a Matcher that returns true for collections of the same elements ignoring the order. +// +// Example usage: +// +// InAnyOrder([]int{1, 2, 3}).Matches([]int{1, 3, 2}) // returns true +// InAnyOrder([]int{1, 2, 3}).Matches([]int{1, 2}) // returns false +func InAnyOrder(x any) Matcher { + return inAnyOrderMatcher{x} +} diff --git a/vendor/golang.org/x/sys/windows/registry/key.go b/vendor/golang.org/x/sys/windows/registry/key.go new file mode 100644 index 0000000..6c8d97b --- /dev/null +++ b/vendor/golang.org/x/sys/windows/registry/key.go @@ -0,0 +1,206 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows +// +build windows + +// Package registry provides access to the Windows registry. +// +// Here is a simple example, opening a registry key and reading a string value from it. +// +// k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE) +// if err != nil { +// log.Fatal(err) +// } +// defer k.Close() +// +// s, _, err := k.GetStringValue("SystemRoot") +// if err != nil { +// log.Fatal(err) +// } +// fmt.Printf("Windows system root is %q\n", s) +package registry + +import ( + "io" + "runtime" + "syscall" + "time" +) + +const ( + // Registry key security and access rights. + // See https://msdn.microsoft.com/en-us/library/windows/desktop/ms724878.aspx + // for details. + ALL_ACCESS = 0xf003f + CREATE_LINK = 0x00020 + CREATE_SUB_KEY = 0x00004 + ENUMERATE_SUB_KEYS = 0x00008 + EXECUTE = 0x20019 + NOTIFY = 0x00010 + QUERY_VALUE = 0x00001 + READ = 0x20019 + SET_VALUE = 0x00002 + WOW64_32KEY = 0x00200 + WOW64_64KEY = 0x00100 + WRITE = 0x20006 +) + +// Key is a handle to an open Windows registry key. +// Keys can be obtained by calling OpenKey; there are +// also some predefined root keys such as CURRENT_USER. +// Keys can be used directly in the Windows API. +type Key syscall.Handle + +const ( + // Windows defines some predefined root keys that are always open. + // An application can use these keys as entry points to the registry. + // Normally these keys are used in OpenKey to open new keys, + // but they can also be used anywhere a Key is required. + CLASSES_ROOT = Key(syscall.HKEY_CLASSES_ROOT) + CURRENT_USER = Key(syscall.HKEY_CURRENT_USER) + LOCAL_MACHINE = Key(syscall.HKEY_LOCAL_MACHINE) + USERS = Key(syscall.HKEY_USERS) + CURRENT_CONFIG = Key(syscall.HKEY_CURRENT_CONFIG) + PERFORMANCE_DATA = Key(syscall.HKEY_PERFORMANCE_DATA) +) + +// Close closes open key k. +func (k Key) Close() error { + return syscall.RegCloseKey(syscall.Handle(k)) +} + +// OpenKey opens a new key with path name relative to key k. +// It accepts any open key, including CURRENT_USER and others, +// and returns the new key and an error. +// The access parameter specifies desired access rights to the +// key to be opened. +func OpenKey(k Key, path string, access uint32) (Key, error) { + p, err := syscall.UTF16PtrFromString(path) + if err != nil { + return 0, err + } + var subkey syscall.Handle + err = syscall.RegOpenKeyEx(syscall.Handle(k), p, 0, access, &subkey) + if err != nil { + return 0, err + } + return Key(subkey), nil +} + +// OpenRemoteKey opens a predefined registry key on another +// computer pcname. The key to be opened is specified by k, but +// can only be one of LOCAL_MACHINE, PERFORMANCE_DATA or USERS. +// If pcname is "", OpenRemoteKey returns local computer key. +func OpenRemoteKey(pcname string, k Key) (Key, error) { + var err error + var p *uint16 + if pcname != "" { + p, err = syscall.UTF16PtrFromString(`\\` + pcname) + if err != nil { + return 0, err + } + } + var remoteKey syscall.Handle + err = regConnectRegistry(p, syscall.Handle(k), &remoteKey) + if err != nil { + return 0, err + } + return Key(remoteKey), nil +} + +// ReadSubKeyNames returns the names of subkeys of key k. +// The parameter n controls the number of returned names, +// analogous to the way os.File.Readdirnames works. +func (k Key) ReadSubKeyNames(n int) ([]string, error) { + // RegEnumKeyEx must be called repeatedly and to completion. + // During this time, this goroutine cannot migrate away from + // its current thread. See https://golang.org/issue/49320 and + // https://golang.org/issue/49466. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + names := make([]string, 0) + // Registry key size limit is 255 bytes and described there: + // https://msdn.microsoft.com/library/windows/desktop/ms724872.aspx + buf := make([]uint16, 256) //plus extra room for terminating zero byte +loopItems: + for i := uint32(0); ; i++ { + if n > 0 { + if len(names) == n { + return names, nil + } + } + l := uint32(len(buf)) + for { + err := syscall.RegEnumKeyEx(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil) + if err == nil { + break + } + if err == syscall.ERROR_MORE_DATA { + // Double buffer size and try again. + l = uint32(2 * len(buf)) + buf = make([]uint16, l) + continue + } + if err == _ERROR_NO_MORE_ITEMS { + break loopItems + } + return names, err + } + names = append(names, syscall.UTF16ToString(buf[:l])) + } + if n > len(names) { + return names, io.EOF + } + return names, nil +} + +// CreateKey creates a key named path under open key k. +// CreateKey returns the new key and a boolean flag that reports +// whether the key already existed. +// The access parameter specifies the access rights for the key +// to be created. +func CreateKey(k Key, path string, access uint32) (newk Key, openedExisting bool, err error) { + var h syscall.Handle + var d uint32 + err = regCreateKeyEx(syscall.Handle(k), syscall.StringToUTF16Ptr(path), + 0, nil, _REG_OPTION_NON_VOLATILE, access, nil, &h, &d) + if err != nil { + return 0, false, err + } + return Key(h), d == _REG_OPENED_EXISTING_KEY, nil +} + +// DeleteKey deletes the subkey path of key k and its values. +func DeleteKey(k Key, path string) error { + return regDeleteKey(syscall.Handle(k), syscall.StringToUTF16Ptr(path)) +} + +// A KeyInfo describes the statistics of a key. It is returned by Stat. +type KeyInfo struct { + SubKeyCount uint32 + MaxSubKeyLen uint32 // size of the key's subkey with the longest name, in Unicode characters, not including the terminating zero byte + ValueCount uint32 + MaxValueNameLen uint32 // size of the key's longest value name, in Unicode characters, not including the terminating zero byte + MaxValueLen uint32 // longest data component among the key's values, in bytes + lastWriteTime syscall.Filetime +} + +// ModTime returns the key's last write time. +func (ki *KeyInfo) ModTime() time.Time { + return time.Unix(0, ki.lastWriteTime.Nanoseconds()) +} + +// Stat retrieves information about the open key k. +func (k Key) Stat() (*KeyInfo, error) { + var ki KeyInfo + err := syscall.RegQueryInfoKey(syscall.Handle(k), nil, nil, nil, + &ki.SubKeyCount, &ki.MaxSubKeyLen, nil, &ki.ValueCount, + &ki.MaxValueNameLen, &ki.MaxValueLen, nil, &ki.lastWriteTime) + if err != nil { + return nil, err + } + return &ki, nil +} diff --git a/vendor/golang.org/x/sys/windows/registry/mksyscall.go b/vendor/golang.org/x/sys/windows/registry/mksyscall.go new file mode 100644 index 0000000..ee74927 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/registry/mksyscall.go @@ -0,0 +1,10 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build generate +// +build generate + +package registry + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall.go diff --git a/vendor/golang.org/x/sys/windows/registry/syscall.go b/vendor/golang.org/x/sys/windows/registry/syscall.go new file mode 100644 index 0000000..4173351 --- /dev/null +++ b/vendor/golang.org/x/sys/windows/registry/syscall.go @@ -0,0 +1,33 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows +// +build windows + +package registry + +import "syscall" + +const ( + _REG_OPTION_NON_VOLATILE = 0 + + _REG_CREATED_NEW_KEY = 1 + _REG_OPENED_EXISTING_KEY = 2 + + _ERROR_NO_MORE_ITEMS syscall.Errno = 259 +) + +func LoadRegLoadMUIString() error { + return procRegLoadMUIStringW.Find() +} + +//sys regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) = advapi32.RegCreateKeyExW +//sys regDeleteKey(key syscall.Handle, subkey *uint16) (regerrno error) = advapi32.RegDeleteKeyW +//sys regSetValueEx(key syscall.Handle, valueName *uint16, reserved uint32, vtype uint32, buf *byte, bufsize uint32) (regerrno error) = advapi32.RegSetValueExW +//sys regEnumValue(key syscall.Handle, index uint32, name *uint16, nameLen *uint32, reserved *uint32, valtype *uint32, buf *byte, buflen *uint32) (regerrno error) = advapi32.RegEnumValueW +//sys regDeleteValue(key syscall.Handle, name *uint16) (regerrno error) = advapi32.RegDeleteValueW +//sys regLoadMUIString(key syscall.Handle, name *uint16, buf *uint16, buflen uint32, buflenCopied *uint32, flags uint32, dir *uint16) (regerrno error) = advapi32.RegLoadMUIStringW +//sys regConnectRegistry(machinename *uint16, key syscall.Handle, result *syscall.Handle) (regerrno error) = advapi32.RegConnectRegistryW + +//sys expandEnvironmentStrings(src *uint16, dst *uint16, size uint32) (n uint32, err error) = kernel32.ExpandEnvironmentStringsW diff --git a/vendor/golang.org/x/sys/windows/registry/value.go b/vendor/golang.org/x/sys/windows/registry/value.go new file mode 100644 index 0000000..2789f6f --- /dev/null +++ b/vendor/golang.org/x/sys/windows/registry/value.go @@ -0,0 +1,387 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows +// +build windows + +package registry + +import ( + "errors" + "io" + "syscall" + "unicode/utf16" + "unsafe" +) + +const ( + // Registry value types. + NONE = 0 + SZ = 1 + EXPAND_SZ = 2 + BINARY = 3 + DWORD = 4 + DWORD_BIG_ENDIAN = 5 + LINK = 6 + MULTI_SZ = 7 + RESOURCE_LIST = 8 + FULL_RESOURCE_DESCRIPTOR = 9 + RESOURCE_REQUIREMENTS_LIST = 10 + QWORD = 11 +) + +var ( + // ErrShortBuffer is returned when the buffer was too short for the operation. + ErrShortBuffer = syscall.ERROR_MORE_DATA + + // ErrNotExist is returned when a registry key or value does not exist. + ErrNotExist = syscall.ERROR_FILE_NOT_FOUND + + // ErrUnexpectedType is returned by Get*Value when the value's type was unexpected. + ErrUnexpectedType = errors.New("unexpected key value type") +) + +// GetValue retrieves the type and data for the specified value associated +// with an open key k. It fills up buffer buf and returns the retrieved +// byte count n. If buf is too small to fit the stored value it returns +// ErrShortBuffer error along with the required buffer size n. +// If no buffer is provided, it returns true and actual buffer size n. +// If no buffer is provided, GetValue returns the value's type only. +// If the value does not exist, the error returned is ErrNotExist. +// +// GetValue is a low level function. If value's type is known, use the appropriate +// Get*Value function instead. +func (k Key) GetValue(name string, buf []byte) (n int, valtype uint32, err error) { + pname, err := syscall.UTF16PtrFromString(name) + if err != nil { + return 0, 0, err + } + var pbuf *byte + if len(buf) > 0 { + pbuf = (*byte)(unsafe.Pointer(&buf[0])) + } + l := uint32(len(buf)) + err = syscall.RegQueryValueEx(syscall.Handle(k), pname, nil, &valtype, pbuf, &l) + if err != nil { + return int(l), valtype, err + } + return int(l), valtype, nil +} + +func (k Key) getValue(name string, buf []byte) (data []byte, valtype uint32, err error) { + p, err := syscall.UTF16PtrFromString(name) + if err != nil { + return nil, 0, err + } + var t uint32 + n := uint32(len(buf)) + for { + err = syscall.RegQueryValueEx(syscall.Handle(k), p, nil, &t, (*byte)(unsafe.Pointer(&buf[0])), &n) + if err == nil { + return buf[:n], t, nil + } + if err != syscall.ERROR_MORE_DATA { + return nil, 0, err + } + if n <= uint32(len(buf)) { + return nil, 0, err + } + buf = make([]byte, n) + } +} + +// GetStringValue retrieves the string value for the specified +// value name associated with an open key k. It also returns the value's type. +// If value does not exist, GetStringValue returns ErrNotExist. +// If value is not SZ or EXPAND_SZ, it will return the correct value +// type and ErrUnexpectedType. +func (k Key) GetStringValue(name string) (val string, valtype uint32, err error) { + data, typ, err2 := k.getValue(name, make([]byte, 64)) + if err2 != nil { + return "", typ, err2 + } + switch typ { + case SZ, EXPAND_SZ: + default: + return "", typ, ErrUnexpectedType + } + if len(data) == 0 { + return "", typ, nil + } + u := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[: len(data)/2 : len(data)/2] + return syscall.UTF16ToString(u), typ, nil +} + +// GetMUIStringValue retrieves the localized string value for +// the specified value name associated with an open key k. +// If the value name doesn't exist or the localized string value +// can't be resolved, GetMUIStringValue returns ErrNotExist. +// GetMUIStringValue panics if the system doesn't support +// regLoadMUIString; use LoadRegLoadMUIString to check if +// regLoadMUIString is supported before calling this function. +func (k Key) GetMUIStringValue(name string) (string, error) { + pname, err := syscall.UTF16PtrFromString(name) + if err != nil { + return "", err + } + + buf := make([]uint16, 1024) + var buflen uint32 + var pdir *uint16 + + err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir) + if err == syscall.ERROR_FILE_NOT_FOUND { // Try fallback path + + // Try to resolve the string value using the system directory as + // a DLL search path; this assumes the string value is of the form + // @[path]\dllname,-strID but with no path given, e.g. @tzres.dll,-320. + + // This approach works with tzres.dll but may have to be revised + // in the future to allow callers to provide custom search paths. + + var s string + s, err = ExpandString("%SystemRoot%\\system32\\") + if err != nil { + return "", err + } + pdir, err = syscall.UTF16PtrFromString(s) + if err != nil { + return "", err + } + + err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir) + } + + for err == syscall.ERROR_MORE_DATA { // Grow buffer if needed + if buflen <= uint32(len(buf)) { + break // Buffer not growing, assume race; break + } + buf = make([]uint16, buflen) + err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir) + } + + if err != nil { + return "", err + } + + return syscall.UTF16ToString(buf), nil +} + +// ExpandString expands environment-variable strings and replaces +// them with the values defined for the current user. +// Use ExpandString to expand EXPAND_SZ strings. +func ExpandString(value string) (string, error) { + if value == "" { + return "", nil + } + p, err := syscall.UTF16PtrFromString(value) + if err != nil { + return "", err + } + r := make([]uint16, 100) + for { + n, err := expandEnvironmentStrings(p, &r[0], uint32(len(r))) + if err != nil { + return "", err + } + if n <= uint32(len(r)) { + return syscall.UTF16ToString(r[:n]), nil + } + r = make([]uint16, n) + } +} + +// GetStringsValue retrieves the []string value for the specified +// value name associated with an open key k. It also returns the value's type. +// If value does not exist, GetStringsValue returns ErrNotExist. +// If value is not MULTI_SZ, it will return the correct value +// type and ErrUnexpectedType. +func (k Key) GetStringsValue(name string) (val []string, valtype uint32, err error) { + data, typ, err2 := k.getValue(name, make([]byte, 64)) + if err2 != nil { + return nil, typ, err2 + } + if typ != MULTI_SZ { + return nil, typ, ErrUnexpectedType + } + if len(data) == 0 { + return nil, typ, nil + } + p := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[: len(data)/2 : len(data)/2] + if len(p) == 0 { + return nil, typ, nil + } + if p[len(p)-1] == 0 { + p = p[:len(p)-1] // remove terminating null + } + val = make([]string, 0, 5) + from := 0 + for i, c := range p { + if c == 0 { + val = append(val, string(utf16.Decode(p[from:i]))) + from = i + 1 + } + } + return val, typ, nil +} + +// GetIntegerValue retrieves the integer value for the specified +// value name associated with an open key k. It also returns the value's type. +// If value does not exist, GetIntegerValue returns ErrNotExist. +// If value is not DWORD or QWORD, it will return the correct value +// type and ErrUnexpectedType. +func (k Key) GetIntegerValue(name string) (val uint64, valtype uint32, err error) { + data, typ, err2 := k.getValue(name, make([]byte, 8)) + if err2 != nil { + return 0, typ, err2 + } + switch typ { + case DWORD: + if len(data) != 4 { + return 0, typ, errors.New("DWORD value is not 4 bytes long") + } + var val32 uint32 + copy((*[4]byte)(unsafe.Pointer(&val32))[:], data) + return uint64(val32), DWORD, nil + case QWORD: + if len(data) != 8 { + return 0, typ, errors.New("QWORD value is not 8 bytes long") + } + copy((*[8]byte)(unsafe.Pointer(&val))[:], data) + return val, QWORD, nil + default: + return 0, typ, ErrUnexpectedType + } +} + +// GetBinaryValue retrieves the binary value for the specified +// value name associated with an open key k. It also returns the value's type. +// If value does not exist, GetBinaryValue returns ErrNotExist. +// If value is not BINARY, it will return the correct value +// type and ErrUnexpectedType. +func (k Key) GetBinaryValue(name string) (val []byte, valtype uint32, err error) { + data, typ, err2 := k.getValue(name, make([]byte, 64)) + if err2 != nil { + return nil, typ, err2 + } + if typ != BINARY { + return nil, typ, ErrUnexpectedType + } + return data, typ, nil +} + +func (k Key) setValue(name string, valtype uint32, data []byte) error { + p, err := syscall.UTF16PtrFromString(name) + if err != nil { + return err + } + if len(data) == 0 { + return regSetValueEx(syscall.Handle(k), p, 0, valtype, nil, 0) + } + return regSetValueEx(syscall.Handle(k), p, 0, valtype, &data[0], uint32(len(data))) +} + +// SetDWordValue sets the data and type of a name value +// under key k to value and DWORD. +func (k Key) SetDWordValue(name string, value uint32) error { + return k.setValue(name, DWORD, (*[4]byte)(unsafe.Pointer(&value))[:]) +} + +// SetQWordValue sets the data and type of a name value +// under key k to value and QWORD. +func (k Key) SetQWordValue(name string, value uint64) error { + return k.setValue(name, QWORD, (*[8]byte)(unsafe.Pointer(&value))[:]) +} + +func (k Key) setStringValue(name string, valtype uint32, value string) error { + v, err := syscall.UTF16FromString(value) + if err != nil { + return err + } + buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[: len(v)*2 : len(v)*2] + return k.setValue(name, valtype, buf) +} + +// SetStringValue sets the data and type of a name value +// under key k to value and SZ. The value must not contain a zero byte. +func (k Key) SetStringValue(name, value string) error { + return k.setStringValue(name, SZ, value) +} + +// SetExpandStringValue sets the data and type of a name value +// under key k to value and EXPAND_SZ. The value must not contain a zero byte. +func (k Key) SetExpandStringValue(name, value string) error { + return k.setStringValue(name, EXPAND_SZ, value) +} + +// SetStringsValue sets the data and type of a name value +// under key k to value and MULTI_SZ. The value strings +// must not contain a zero byte. +func (k Key) SetStringsValue(name string, value []string) error { + ss := "" + for _, s := range value { + for i := 0; i < len(s); i++ { + if s[i] == 0 { + return errors.New("string cannot have 0 inside") + } + } + ss += s + "\x00" + } + v := utf16.Encode([]rune(ss + "\x00")) + buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[: len(v)*2 : len(v)*2] + return k.setValue(name, MULTI_SZ, buf) +} + +// SetBinaryValue sets the data and type of a name value +// under key k to value and BINARY. +func (k Key) SetBinaryValue(name string, value []byte) error { + return k.setValue(name, BINARY, value) +} + +// DeleteValue removes a named value from the key k. +func (k Key) DeleteValue(name string) error { + return regDeleteValue(syscall.Handle(k), syscall.StringToUTF16Ptr(name)) +} + +// ReadValueNames returns the value names of key k. +// The parameter n controls the number of returned names, +// analogous to the way os.File.Readdirnames works. +func (k Key) ReadValueNames(n int) ([]string, error) { + ki, err := k.Stat() + if err != nil { + return nil, err + } + names := make([]string, 0, ki.ValueCount) + buf := make([]uint16, ki.MaxValueNameLen+1) // extra room for terminating null character +loopItems: + for i := uint32(0); ; i++ { + if n > 0 { + if len(names) == n { + return names, nil + } + } + l := uint32(len(buf)) + for { + err := regEnumValue(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil) + if err == nil { + break + } + if err == syscall.ERROR_MORE_DATA { + // Double buffer size and try again. + l = uint32(2 * len(buf)) + buf = make([]uint16, l) + continue + } + if err == _ERROR_NO_MORE_ITEMS { + break loopItems + } + return names, err + } + names = append(names, syscall.UTF16ToString(buf[:l])) + } + if n > len(names) { + return names, io.EOF + } + return names, nil +} diff --git a/vendor/golang.org/x/sys/windows/registry/zsyscall_windows.go b/vendor/golang.org/x/sys/windows/registry/zsyscall_windows.go new file mode 100644 index 0000000..fc1835d --- /dev/null +++ b/vendor/golang.org/x/sys/windows/registry/zsyscall_windows.go @@ -0,0 +1,117 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package registry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procRegConnectRegistryW = modadvapi32.NewProc("RegConnectRegistryW") + procRegCreateKeyExW = modadvapi32.NewProc("RegCreateKeyExW") + procRegDeleteKeyW = modadvapi32.NewProc("RegDeleteKeyW") + procRegDeleteValueW = modadvapi32.NewProc("RegDeleteValueW") + procRegEnumValueW = modadvapi32.NewProc("RegEnumValueW") + procRegLoadMUIStringW = modadvapi32.NewProc("RegLoadMUIStringW") + procRegSetValueExW = modadvapi32.NewProc("RegSetValueExW") + procExpandEnvironmentStringsW = modkernel32.NewProc("ExpandEnvironmentStringsW") +) + +func regConnectRegistry(machinename *uint16, key syscall.Handle, result *syscall.Handle) (regerrno error) { + r0, _, _ := syscall.Syscall(procRegConnectRegistryW.Addr(), 3, uintptr(unsafe.Pointer(machinename)), uintptr(key), uintptr(unsafe.Pointer(result))) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regCreateKeyEx(key syscall.Handle, subkey *uint16, reserved uint32, class *uint16, options uint32, desired uint32, sa *syscall.SecurityAttributes, result *syscall.Handle, disposition *uint32) (regerrno error) { + r0, _, _ := syscall.Syscall9(procRegCreateKeyExW.Addr(), 9, uintptr(key), uintptr(unsafe.Pointer(subkey)), uintptr(reserved), uintptr(unsafe.Pointer(class)), uintptr(options), uintptr(desired), uintptr(unsafe.Pointer(sa)), uintptr(unsafe.Pointer(result)), uintptr(unsafe.Pointer(disposition))) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regDeleteKey(key syscall.Handle, subkey *uint16) (regerrno error) { + r0, _, _ := syscall.Syscall(procRegDeleteKeyW.Addr(), 2, uintptr(key), uintptr(unsafe.Pointer(subkey)), 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regDeleteValue(key syscall.Handle, name *uint16) (regerrno error) { + r0, _, _ := syscall.Syscall(procRegDeleteValueW.Addr(), 2, uintptr(key), uintptr(unsafe.Pointer(name)), 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regEnumValue(key syscall.Handle, index uint32, name *uint16, nameLen *uint32, reserved *uint32, valtype *uint32, buf *byte, buflen *uint32) (regerrno error) { + r0, _, _ := syscall.Syscall9(procRegEnumValueW.Addr(), 8, uintptr(key), uintptr(index), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(nameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valtype)), uintptr(unsafe.Pointer(buf)), uintptr(unsafe.Pointer(buflen)), 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regLoadMUIString(key syscall.Handle, name *uint16, buf *uint16, buflen uint32, buflenCopied *uint32, flags uint32, dir *uint16) (regerrno error) { + r0, _, _ := syscall.Syscall9(procRegLoadMUIStringW.Addr(), 7, uintptr(key), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(buf)), uintptr(buflen), uintptr(unsafe.Pointer(buflenCopied)), uintptr(flags), uintptr(unsafe.Pointer(dir)), 0, 0) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func regSetValueEx(key syscall.Handle, valueName *uint16, reserved uint32, vtype uint32, buf *byte, bufsize uint32) (regerrno error) { + r0, _, _ := syscall.Syscall6(procRegSetValueExW.Addr(), 6, uintptr(key), uintptr(unsafe.Pointer(valueName)), uintptr(reserved), uintptr(vtype), uintptr(unsafe.Pointer(buf)), uintptr(bufsize)) + if r0 != 0 { + regerrno = syscall.Errno(r0) + } + return +} + +func expandEnvironmentStrings(src *uint16, dst *uint16, size uint32) (n uint32, err error) { + r0, _, e1 := syscall.Syscall(procExpandEnvironmentStringsW.Addr(), 3, uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(size)) + n = uint32(r0) + if n == 0 { + err = errnoErr(e1) + } + return +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 53f9f12..6b12bc3 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -3,10 +3,12 @@ github.com/Azure/azure-docker-extension/pkg/vmextension # github.com/Azure/azure-extension-platform v0.0.0-20240521173920-6b2acfda81e9 ## explicit; go 1.21 +github.com/Azure/azure-extension-platform/pkg/constants github.com/Azure/azure-extension-platform/pkg/extensionerrors github.com/Azure/azure-extension-platform/pkg/extensionevents github.com/Azure/azure-extension-platform/pkg/handlerenv github.com/Azure/azure-extension-platform/pkg/logging +github.com/Azure/azure-extension-platform/pkg/seqno github.com/Azure/azure-extension-platform/pkg/utils # github.com/cilium/ebpf v0.9.1 ## explicit; go 1.17 @@ -69,11 +71,15 @@ github.com/xeipuuv/gojsonreference # github.com/xeipuuv/gojsonschema v0.0.0-20160623135812-c539bca196be ## explicit github.com/xeipuuv/gojsonschema +# go.uber.org/mock v0.4.0 +## explicit; go 1.20 +go.uber.org/mock/gomock # golang.org/x/sys v0.2.0 ## explicit; go 1.17 golang.org/x/sys/internal/unsafeheader golang.org/x/sys/unix golang.org/x/sys/windows +golang.org/x/sys/windows/registry # google.golang.org/protobuf v1.33.0 ## explicit; go 1.17 google.golang.org/protobuf/encoding/prototext