Skip to content

Commit c9571c3

Browse files
committed
Refactor service registration (#8976)
* serviceregistration: refactor service registration logic to run later
* move state check to the internal func
* sr/kubernetes: update setInitialStateInternal godoc
* sr/kubernetes: remove return in setInitialState
* core/test: fix mockServiceRegistration
* address review feedback
1 parent ae6cc3d commit c9571c3

File tree

13 files changed

+355
-253
lines changed

13 files changed

+355
-253
lines changed

command/server.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -966,9 +966,6 @@ func (c *ServerCommand) Run(args []string) int {
966966
return 1
967967
}
968968

969-
// Instantiate the wait group
970-
c.WaitGroup = &sync.WaitGroup{}
971-
972969
// Initialize the Service Discovery, if there is one
973970
var configSR sr.ServiceRegistration
974971
if config.ServiceRegistration != nil {
@@ -990,15 +987,11 @@ func (c *ServerCommand) Run(args []string) int {
990987
IsActive: false,
991988
IsPerformanceStandby: false,
992989
}
993-
configSR, err = sdFactory(config.ServiceRegistration.Config, namedSDLogger, state, config.Storage.RedirectAddr)
990+
configSR, err = sdFactory(config.ServiceRegistration.Config, namedSDLogger, state)
994991
if err != nil {
995992
c.UI.Error(fmt.Sprintf("Error initializing service_registration of type %s: %s", config.ServiceRegistration.Type, err))
996993
return 1
997994
}
998-
if err := configSR.Run(c.ShutdownCh, c.WaitGroup); err != nil {
999-
c.UI.Error(fmt.Sprintf("Error running service_registration of type %s: %s", config.ServiceRegistration.Type, err))
1000-
return 1
1001-
}
1002995
}
1003996

1004997
infoKeys := make([]string, 0, 10)
@@ -1311,7 +1304,7 @@ CLUSTER_SYNTHESIS_COMPLETE:
13111304

13121305
// If ServiceRegistration is configured, then the backend must support HA
13131306
isBackendHA := coreConfig.HAPhysical != nil && coreConfig.HAPhysical.HAEnabled()
1314-
if !c.flagDev && (coreConfig.ServiceRegistration != nil) && !isBackendHA {
1307+
if !c.flagDev && (coreConfig.GetServiceRegistration() != nil) && !isBackendHA {
13151308
c.UI.Output("service_registration is configured, but storage does not support HA")
13161309
return 1
13171310
}
@@ -1578,6 +1571,18 @@ CLUSTER_SYNTHESIS_COMPLETE:
15781571
}
15791572

15801573
// Perform initialization of HTTP server after the verifyOnly check.
1574+
1575+
// Instantiate the wait group
1576+
c.WaitGroup = &sync.WaitGroup{}
1577+
1578+
// If service discovery is available, run service discovery
1579+
if sd := coreConfig.GetServiceRegistration(); sd != nil {
1580+
if err := configSR.Run(c.ShutdownCh, c.WaitGroup, coreConfig.RedirectAddr); err != nil {
1581+
c.UI.Error(fmt.Sprintf("Error running service_registration of type %s: %s", config.ServiceRegistration.Type, err))
1582+
return 1
1583+
}
1584+
}
1585+
15811586
// If we're in Dev mode, then initialize the core
15821587
if c.flagDev && !c.flagDevSkipInit {
15831588
init, err := c.enableDev(core, coreConfig)

serviceregistration/consul/consul_service_registration.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ type serviceRegistration struct {
6868
serviceAddress *string
6969
disableRegistration bool
7070
checkTimeout time.Duration
71-
redirectAddr string
7271

7372
notifyActiveCh chan struct{}
7473
notifySealedCh chan struct{}
@@ -78,8 +77,7 @@ type serviceRegistration struct {
7877
}
7978

8079
// NewServiceRegistration constructs a Consul-based ServiceRegistration.
81-
func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.State, redirectAddr string) (sr.ServiceRegistration, error) {
82-
80+
func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.State) (sr.ServiceRegistration, error) {
8381
// Allow admins to disable consul integration
8482
disableReg, ok := conf["disable_registration"]
8583
var disableRegistration bool
@@ -208,7 +206,6 @@ func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.
208206
serviceAddress: serviceAddr,
209207
checkTimeout: checkTimeout,
210208
disableRegistration: disableRegistration,
211-
redirectAddr: redirectAddr,
212209

213210
notifyActiveCh: make(chan struct{}),
214211
notifySealedCh: make(chan struct{}),
@@ -221,9 +218,9 @@ func NewServiceRegistration(conf map[string]string, logger log.Logger, state sr.
221218
return c, nil
222219
}
223220

224-
func (c *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup) error {
221+
func (c *serviceRegistration) Run(shutdownCh <-chan struct{}, wait *sync.WaitGroup, redirectAddr string) error {
225222
go func() {
226-
if err := c.runServiceRegistration(wait, shutdownCh, c.redirectAddr); err != nil {
223+
if err := c.runServiceRegistration(wait, shutdownCh, redirectAddr); err != nil {
227224
if c.logger.IsError() {
228225
c.logger.Error(fmt.Sprintf("error running service registration: %s", err))
229226
}
@@ -290,12 +287,12 @@ func (c *serviceRegistration) runServiceRegistration(waitGroup *sync.WaitGroup,
290287
// 'server' command will wait for the below goroutine to complete
291288
waitGroup.Add(1)
292289

293-
go c.runEventDemuxer(waitGroup, shutdownCh, redirectAddr)
290+
go c.runEventDemuxer(waitGroup, shutdownCh)
294291

295292
return nil
296293
}
297294

298-
func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh <-chan struct{}, redirectAddr string) {
295+
func (c *serviceRegistration) runEventDemuxer(waitGroup *sync.WaitGroup, shutdownCh <-chan struct{}) {
299296
// This defer statement should be executed last. So push it first.
300297
defer waitGroup.Done()
301298

serviceregistration/consul/consul_service_registration_test.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ func testConsulServiceRegistrationConfig(t *testing.T, conf *consulConf) *servic
3232
defer func() {
3333
close(shutdownCh)
3434
}()
35-
be, err := NewServiceRegistration(*conf, logger, sr.State{}, "")
35+
be, err := NewServiceRegistration(*conf, logger, sr.State{})
3636
if err != nil {
3737
t.Fatalf("Expected Consul to initialize: %v", err)
3838
}
39-
if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
39+
if err := be.Run(shutdownCh, &sync.WaitGroup{}, ""); err != nil {
4040
t.Fatal(err)
4141
}
4242

@@ -69,8 +69,10 @@ func TestConsul_ServiceRegistration(t *testing.T) {
6969
waitForServices := func(t *testing.T, expected map[string][]string) map[string][]string {
7070
t.Helper()
7171
// Wait for up to 10 seconds
72+
var services map[string][]string
73+
var err error
7274
for i := 0; i < 10; i++ {
73-
services, _, err := client.Catalog().Services(nil)
75+
services, _, err = client.Catalog().Services(nil)
7476
if err != nil {
7577
t.Fatal(err)
7678
}
@@ -79,7 +81,7 @@ func TestConsul_ServiceRegistration(t *testing.T) {
7981
}
8082
time.Sleep(time.Second)
8183
}
82-
t.Fatalf("Catalog Services never reached expected state %v", expected)
84+
t.Fatalf("Catalog Services never reached: got: %v, expected state: %v", services, expected)
8385
return nil
8486
}
8587

@@ -94,11 +96,11 @@ func TestConsul_ServiceRegistration(t *testing.T) {
9496
sd, err := NewServiceRegistration(map[string]string{
9597
"address": addr,
9698
"token": token,
97-
}, logger, sr.State{}, redirectAddr)
99+
}, logger, sr.State{})
98100
if err != nil {
99101
t.Fatal(err)
100102
}
101-
if err := sd.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
103+
if err := sd.Run(shutdownCh, &sync.WaitGroup{}, redirectAddr); err != nil {
102104
t.Fatal(err)
103105
}
104106

@@ -167,11 +169,11 @@ func TestConsul_ServiceTags(t *testing.T) {
167169
close(shutdownCh)
168170
}()
169171

170-
be, err := NewServiceRegistration(consulConfig, logger, sr.State{}, "")
172+
be, err := NewServiceRegistration(consulConfig, logger, sr.State{})
171173
if err != nil {
172174
t.Fatal(err)
173175
}
174-
if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
176+
if err := be.Run(shutdownCh, &sync.WaitGroup{}, ""); err != nil {
175177
t.Fatal(err)
176178
}
177179

@@ -226,11 +228,11 @@ func TestConsul_ServiceAddress(t *testing.T) {
226228
shutdownCh := make(chan struct{})
227229
logger := logging.NewVaultLogger(log.Debug)
228230

229-
be, err := NewServiceRegistration(test.consulConfig, logger, sr.State{}, "")
231+
be, err := NewServiceRegistration(test.consulConfig, logger, sr.State{})
230232
if err != nil {
231233
t.Fatalf("expected Consul to initialize: %v", err)
232234
}
233-
if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
235+
if err := be.Run(shutdownCh, &sync.WaitGroup{}, ""); err != nil {
234236
t.Fatal(err)
235237
}
236238

@@ -355,7 +357,7 @@ func TestConsul_newConsulServiceRegistration(t *testing.T) {
355357
shutdownCh := make(chan struct{})
356358
logger := logging.NewVaultLogger(log.Debug)
357359

358-
be, err := NewServiceRegistration(test.consulConfig, logger, sr.State{}, "")
360+
be, err := NewServiceRegistration(test.consulConfig, logger, sr.State{})
359361
if test.fail {
360362
if err == nil {
361363
t.Fatalf(`Expected config "%s" to fail`, test.name)
@@ -365,7 +367,7 @@ func TestConsul_newConsulServiceRegistration(t *testing.T) {
365367
} else if !test.fail && err != nil {
366368
t.Fatalf("Expected config %s to not fail: %v", test.name, err)
367369
}
368-
if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
370+
if err := be.Run(shutdownCh, &sync.WaitGroup{}, ""); err != nil {
369371
t.Fatal(err)
370372
}
371373

@@ -559,7 +561,7 @@ func TestConsul_serviceID(t *testing.T) {
559561
shutdownCh := make(chan struct{})
560562
be, err := NewServiceRegistration(consulConf{
561563
"service": test.serviceName,
562-
}, logger, sr.State{}, "")
564+
}, logger, sr.State{})
563565
if !test.valid {
564566
if err == nil {
565567
t.Fatalf("expected an error initializing for name %q", test.serviceName)
@@ -569,7 +571,7 @@ func TestConsul_serviceID(t *testing.T) {
569571
if test.valid && err != nil {
570572
t.Fatalf("expected Consul to initialize: %v", err)
571573
}
572-
if err := be.Run(shutdownCh, &sync.WaitGroup{}); err != nil {
574+
if err := be.Run(shutdownCh, &sync.WaitGroup{}, ""); err != nil {
573575
t.Fatal(err)
574576
}
575577

serviceregistration/kubernetes/client/client.go

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,29 +29,33 @@ var (
2929
ErrNotInCluster = errors.New("unable to load in-cluster configuration, KUBERNETES_SERVICE_HOST and KUBERNETES_SERVICE_PORT must be defined")
3030
)
3131

32+
// Client is a minimal Kubernetes client. We rolled our own because the existing
33+
// Kubernetes client-go library available externally has a high number of dependencies
34+
// and we thought it wasn't worth it for only two API calls. If at some point they break
35+
// the client into smaller modules, or if we add quite a few methods to this client, it may
36+
// be worthwhile to revisit that decision.
37+
type Client struct {
38+
logger hclog.Logger
39+
config *Config
40+
stopCh chan struct{}
41+
}
42+
3243
// New instantiates a Client. The Client creates its own stopCh, which is
3344
// closed by Shutdown and is used for exiting retry loops.
34-
func New(logger hclog.Logger, stopCh <-chan struct{}) (*Client, error) {
45+
func New(logger hclog.Logger) (*Client, error) {
3546
config, err := inClusterConfig()
3647
if err != nil {
3748
return nil, err
3849
}
3950
return &Client{
4051
logger: logger,
4152
config: config,
42-
stopCh: stopCh,
53+
stopCh: make(chan struct{}),
4354
}, nil
4455
}
4556

46-
// Client is a minimal Kubernetes client. We rolled our own because the existing
47-
// Kubernetes client-go library available externally has a high number of dependencies
48-
// and we thought it wasn't worth it for only two API calls. If at some point they break
49-
// the client into smaller modules, or if we add quite a few methods to this client, it may
50-
// be worthwhile to revisit that decision.
51-
type Client struct {
52-
logger hclog.Logger
53-
config *Config
54-
stopCh <-chan struct{}
57+
func (c *Client) Shutdown() {
58+
close(c.stopCh)
5559
}
5660

5761
// GetPod gets a pod from the Kubernetes API.
@@ -132,10 +136,13 @@ func (c *Client) do(req *http.Request, ptrToReturnObj interface{}) error {
132136
// a stop from our stopChan. This allows us to exit from our retry
133137
// loop during a shutdown, rather than hanging.
134138
ctx, cancelFunc := context.WithCancel(context.Background())
135-
go func(stopCh <-chan struct{}) {
136-
<-stopCh
137-
cancelFunc()
138-
}(c.stopCh)
139+
go func() {
140+
select {
141+
case <-ctx.Done():
142+
case <-c.stopCh:
143+
cancelFunc()
144+
}
145+
}()
139146
retryableReq.WithContext(ctx)
140147

141148
retryableReq.Header.Set("Authorization", "Bearer "+c.config.BearerToken)

serviceregistration/kubernetes/client/client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TestClient(t *testing.T) {
2222
t.Fatal(err)
2323
}
2424

25-
client, err := New(hclog.Default(), make(chan struct{}))
25+
client, err := New(hclog.Default())
2626
if err != nil {
2727
t.Fatal(err)
2828
}

serviceregistration/kubernetes/client/cmd/kubeclient/main.go

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ import (
2121
"encoding/json"
2222
"flag"
2323
"fmt"
24+
"os"
25+
"os/signal"
2426
"strings"
27+
"syscall"
2528

2629
"github.com/hashicorp/go-hclog"
2730
"github.com/hashicorp/vault/serviceregistration/kubernetes/client"
@@ -42,39 +45,64 @@ func init() {
4245
func main() {
4346
flag.Parse()
4447

45-
c, err := client.New(hclog.Default(), make(chan struct{}))
48+
c, err := client.New(hclog.Default())
4649
if err != nil {
4750
panic(err)
4851
}
4952

50-
switch callToMake {
51-
case "get-pod":
52-
pod, err := c.GetPod(namespace, podName)
53-
if err != nil {
54-
panic(err)
55-
}
56-
b, _ := json.Marshal(pod)
57-
fmt.Printf("pod: %s\n", b)
58-
return
59-
case "patch-pod":
60-
patchPairs := strings.Split(patchesToAdd, ",")
61-
var patches []*client.Patch
62-
for _, patchPair := range patchPairs {
63-
fields := strings.Split(patchPair, ":")
64-
if len(fields) != 2 {
65-
panic(fmt.Errorf("unable to split %s from selectors provided of %s", fields, patchesToAdd))
53+
reqCh := make(chan struct{})
54+
shutdownCh := makeShutdownCh()
55+
56+
go func() {
57+
defer close(reqCh)
58+
59+
switch callToMake {
60+
case "get-pod":
61+
pod, err := c.GetPod(namespace, podName)
62+
if err != nil {
63+
panic(err)
6664
}
67-
patches = append(patches, &client.Patch{
68-
Operation: client.Replace,
69-
Path: fields[0],
70-
Value: fields[1],
71-
})
72-
}
73-
if err := c.PatchPod(namespace, podName, patches...); err != nil {
74-
panic(err)
65+
b, _ := json.Marshal(pod)
66+
fmt.Printf("pod: %s\n", b)
67+
return
68+
case "patch-pod":
69+
patchPairs := strings.Split(patchesToAdd, ",")
70+
var patches []*client.Patch
71+
for _, patchPair := range patchPairs {
72+
fields := strings.Split(patchPair, ":")
73+
if len(fields) != 2 {
74+
panic(fmt.Errorf("unable to split %s from selectors provided of %s", fields, patchesToAdd))
75+
}
76+
patches = append(patches, &client.Patch{
77+
Operation: client.Replace,
78+
Path: fields[0],
79+
Value: fields[1],
80+
})
81+
}
82+
if err := c.PatchPod(namespace, podName, patches...); err != nil {
83+
panic(err)
84+
}
85+
return
86+
default:
87+
panic(fmt.Errorf(`unsupported call provided: %q`, callToMake))
7588
}
76-
return
77-
default:
78-
panic(fmt.Errorf(`unsupported call provided: %q`, callToMake))
89+
}()
90+
91+
select {
92+
case <-shutdownCh:
93+
fmt.Println("Interrupt received, exiting...")
94+
case <-reqCh:
7995
}
8096
}
97+
98+
func makeShutdownCh() chan struct{} {
99+
resultCh := make(chan struct{})
100+
101+
shutdownCh := make(chan os.Signal, 4)
102+
signal.Notify(shutdownCh, os.Interrupt, syscall.SIGTERM)
103+
go func() {
104+
<-shutdownCh
105+
close(resultCh)
106+
}()
107+
return resultCh
108+
}

0 commit comments

Comments
 (0)