diff --git a/consul/config.go b/consul/config.go index e482cf0..d504f4c 100644 --- a/consul/config.go +++ b/consul/config.go @@ -15,7 +15,7 @@ type Config struct { } type Upstream struct { - Service string + Name string LocalBindAddress string LocalBindPort int Protocol string diff --git a/consul/watcher.go b/consul/watcher.go index f30958e..6211007 100644 --- a/consul/watcher.go +++ b/consul/watcher.go @@ -2,6 +2,8 @@ package consul import ( "crypto/x509" + "fmt" + "reflect" "sync" "time" @@ -13,13 +15,14 @@ const ( defaultDownstreamBindAddr = "0.0.0.0" defaultUpstreamBindAddr = "127.0.0.1" - errorWaitTime = 5 * time.Second + errorWaitTime = 5 * time.Second + preparedQueryPollInterval = 30 * time.Second ) type upstream struct { LocalBindAddress string LocalBindPort int - Service string + Name string Datacenter string Protocol string Nodes []*api.ServiceEntry @@ -138,12 +141,18 @@ func (w *Watcher) handleProxyChange(first bool, srv *api.AgentService) { if srv.Proxy != nil { for _, up := range srv.Proxy.Upstreams { - keep[up.DestinationName] = true + name := fmt.Sprintf("%s_%s", up.DestinationType, up.DestinationName) + keep[name] = true w.lock.Lock() - _, ok := w.upstreams[up.DestinationName] + _, ok := w.upstreams[name] w.lock.Unlock() if !ok { - w.startUpstream(up) + switch up.DestinationType { + case api.UpstreamDestTypePreparedQuery: + w.startUpstreamPreparedQuery(up, name) + default: + w.startUpstreamService(up, name) + } } } } @@ -159,24 +168,22 @@ func (w *Watcher) handleProxyChange(first bool, srv *api.AgentService) { } } -func (w *Watcher) startUpstream(up api.Upstream) { +func (w *Watcher) startUpstreamService(up api.Upstream, name string) { w.log.Infof("consul: watching upstream for service %s", up.DestinationName) u := &upstream{ LocalBindAddress: up.LocalBindAddress, LocalBindPort: up.LocalBindPort, - Service: up.DestinationName, + Name: name, Datacenter: up.Datacenter, } - if up.Config["protocol"] != nil { - if p, ok := up.Config["protocol"].(string); ok { - u.Protocol = p - } + if p, ok := up.Config["protocol"].(string); ok { + u.Protocol = p } w.lock.Lock() - w.upstreams[up.DestinationName] = u + w.upstreams[name] = u w.lock.Unlock() go func() { @@ -209,6 +216,75 @@ func (w *Watcher) startUpstream(up api.Upstream) { }() } +func (w *Watcher) startUpstreamPreparedQuery(up api.Upstream, name string) { + w.log.Infof("consul: watching upstream for prepared_query %s", up.DestinationName) + + u := &upstream{ + LocalBindAddress: up.LocalBindAddress, + LocalBindPort: up.LocalBindPort, + Name: name, + Datacenter: up.Datacenter, + } + + if p, ok := up.Config["protocol"].(string); ok { + u.Protocol = p + } + + interval := preparedQueryPollInterval + if p, ok := up.Config["poll_interval"].(string); ok { + dur, err := time.ParseDuration(p) + if err != nil { + w.log.Errorf( + "consul: upstream %s %s: invalid poll interval %s: %s", + up.DestinationType, + up.DestinationName, + p, + err, + ) + return + } + interval = dur + } + + w.lock.Lock() + w.upstreams[name] = u + w.lock.Unlock() + + go func() { + var last []*api.ServiceEntry + for { + if u.done { + return + } + nodes, _, err := w.consul.PreparedQuery().Execute(up.DestinationName, &api.QueryOptions{ + Connect: true, + Datacenter: up.Datacenter, + WaitTime: 10 * time.Minute, + }) + if err != nil { + w.log.Errorf("consul: error fetching service definition for service %s: %s", up.DestinationName, err) + time.Sleep(errorWaitTime) + continue + } + + nodesP := []*api.ServiceEntry{} + for i := range nodes.Nodes { + nodesP = append(nodesP, &nodes.Nodes[i]) + } + + if !reflect.DeepEqual(last, nodesP) { + w.lock.Lock() + u.Nodes = nodesP + w.lock.Unlock() + w.notifyChanged() + last = nodesP + } + + time.Sleep(interval) + } + }() +} + func (w *Watcher) removeUpstream(name string) { w.log.Infof("consul: removing upstream for service %s", name) @@ -366,7 +442,7 @@ func (w *Watcher) genCfg() Config { for _, up := range w.upstreams { upstream := Upstream{ - Service: up.Service, + Name: up.Name, LocalBindAddress: up.LocalBindAddress, LocalBindPort: up.LocalBindPort, Protocol: up.Protocol, diff --git a/haproxy/state/snapshot_test.go b/haproxy/state/snapshot_test.go index b397232..636910a 100644 --- a/haproxy/state/snapshot_test.go +++ b/haproxy/state/snapshot_test.go @@ -21,7 +21,7 @@ func GetTestConsulConfig() consul.Config { }, Upstreams: []consul.Upstream{ consul.Upstream{ - Service: "service_1", + Name: "service_1", LocalBindAddress: "127.0.0.1", LocalBindPort: 10000, Nodes: []consul.UpstreamNode{ diff --git a/haproxy/state/upstream.go b/haproxy/state/upstream.go index f8ee6bd..3509967 100644 --- a/haproxy/state/upstream.go +++ b/haproxy/state/upstream.go @@ -8,8 +8,8 @@ import ( ) func generateUpstream(opts Options, certStore CertificateStore, cfg consul.Upstream, oldState, newState State) (State, error) { - feName := fmt.Sprintf("front_%s", cfg.Service) - beName := fmt.Sprintf("back_%s", cfg.Service) + feName := fmt.Sprintf("front_%s", cfg.Name) + beName := fmt.Sprintf("back_%s", cfg.Name) feMode := models.FrontendModeHTTP beMode := models.BackendModeHTTP diff --git a/haproxy_test.go b/haproxy_test.go index 201b11f..522cc07 100644 --- a/haproxy_test.go +++ b/haproxy_test.go @@ -4,6 +4,7 @@ import ( "fmt" "io/ioutil" "testing" + "time" "net/http" @@ -13,7 +14,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestSetup(t *testing.T) { +func TestService(t *testing.T) { err := haproxy_cmd.CheckEnvironment(haproxy_cmd.DefaultDataplaneBin, haproxy_cmd.DefaultHAProxyBin) if err != nil { t.Skipf("CANNOT Run test because of missing requirement: %s", err.Error()) @@ -64,3 +65,72 @@ func TestSetup(t *testing.T) { res.Body.Close() require.Equal(t, "hello connect", string(body)) } + +func TestPreparedQuery(t *testing.T) { + err := haproxy_cmd.CheckEnvironment(haproxy_cmd.DefaultDataplaneBin, haproxy_cmd.DefaultHAProxyBin) + if err != nil { + t.Skipf("CANNOT Run test because of missing requirement: %s", err.Error()) + } + sd := lib.NewShutdown() + client := startAgent(t, sd) + defer func() { + sd.Shutdown("test end") + sd.Wait() + }() + + _, _, err = client.PreparedQuery().Create(&api.PreparedQueryDefinition{ + Name: "pq-", + Service: api.ServiceQuery{ + Service: "${match(1)}", + OnlyPassing: true, + }, + Template: api.QueryTemplate{ + Type: "name_prefix_match", + Regexp: "^pq-(.+)$", + }, + }, &api.WriteOptions{}) + require.NoError(t, err) + + csd, _, upstreamPorts := startConnectService(t, sd, client, &api.AgentServiceRegistration{ + Name: "source", + ID: "source-1", + + Connect: &api.AgentServiceConnect{ + SidecarService: &api.AgentServiceRegistration{ + Proxy: &api.AgentServiceConnectProxyConfig{ + Upstreams: []api.Upstream{ + api.Upstream{ + DestinationType: api.UpstreamDestTypePreparedQuery, + DestinationName: "pq-target", + Config: map[string]interface{}{ + "poll_interval": (100 * time.Millisecond).String(), + }, + }, + }, + }, + }, + }, + }) + + tsd, servicePort, _ := startConnectService(t, sd, client, &api.AgentServiceRegistration{ + Name: "target", + ID: "target-1", + + Connect: &api.AgentServiceConnect{ + SidecarService: &api.AgentServiceRegistration{ + Proxy: &api.AgentServiceConnectProxyConfig{}, + }, + }, + }) + + startServer(t, sd, servicePort, "hello connect prepared query") + wait(sd, csd, tsd) + res, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d", upstreamPorts["pq-target"])) + require.NoError(t, err) + require.Equal(t, 200, res.StatusCode) + + body, err := ioutil.ReadAll(res.Body) + require.NoError(t, err) + res.Body.Close() + require.Equal(t, "hello connect prepared query", string(body)) +} diff --git a/utils_test.go b/utils_test.go index 0591ecf..0b7803f 100644 --- a/utils_test.go +++ b/utils_test.go @@ -139,7 +139,7 @@ func startServer(t *testing.T, sd *lib.Shutdown, port int, response string) { sd.Add(1) go func() { http.Serve(lis, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.Write([]byte("hello connect")) + rw.Write([]byte(response)) })) }() go func() {