Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions cmd/operator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ import (
"os"
"time"

eventingv1alpha1 "github.com/kedacore/keda/v2/apis/eventing/v1alpha1"
kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1"
eventingcontrollers "github.com/kedacore/keda/v2/controllers/eventing"
kedacontrollers "github.com/kedacore/keda/v2/controllers/keda"
"github.com/kedacore/keda/v2/pkg/certificates"
"github.com/kedacore/keda/v2/pkg/eventemitter"
"github.com/kedacore/keda/v2/pkg/k8s"
"github.com/kedacore/keda/v2/pkg/metricscollector"
"github.com/kedacore/keda/v2/pkg/metricsservice"
"github.com/kedacore/keda/v2/pkg/scalers/authentication"
"github.com/kedacore/keda/v2/pkg/scalers/connectionpool"
"github.com/kedacore/keda/v2/pkg/scaling"
kedautil "github.com/kedacore/keda/v2/pkg/util"
"github.com/spf13/pflag"
apimachineryruntime "k8s.io/apimachinery/pkg/runtime"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
Expand All @@ -36,19 +49,6 @@ import (
"sigs.k8s.io/controller-runtime/pkg/log/zap"
"sigs.k8s.io/controller-runtime/pkg/metrics/server"
"sigs.k8s.io/controller-runtime/pkg/webhook"

eventingv1alpha1 "github.com/kedacore/keda/v2/apis/eventing/v1alpha1"
kedav1alpha1 "github.com/kedacore/keda/v2/apis/keda/v1alpha1"
eventingcontrollers "github.com/kedacore/keda/v2/controllers/eventing"
kedacontrollers "github.com/kedacore/keda/v2/controllers/keda"
"github.com/kedacore/keda/v2/pkg/certificates"
"github.com/kedacore/keda/v2/pkg/eventemitter"
"github.com/kedacore/keda/v2/pkg/k8s"
"github.com/kedacore/keda/v2/pkg/metricscollector"
"github.com/kedacore/keda/v2/pkg/metricsservice"
"github.com/kedacore/keda/v2/pkg/scalers/authentication"
"github.com/kedacore/keda/v2/pkg/scaling"
kedautil "github.com/kedacore/keda/v2/pkg/util"
//+kubebuilder:scaffold:imports
)

Expand Down Expand Up @@ -87,6 +87,7 @@ func main() {
var validatingWebhookName string
var caDirs []string
var enableWebhookPatching bool
var connectionPoolConfigPath string
pflag.BoolVar(&enablePrometheusMetrics, "enable-prometheus-metrics", true, "Enable the prometheus metric of keda-operator.")
pflag.BoolVar(&enableOpenTelemetryMetrics, "enable-opentelemetry-metrics", false, "Enable the opentelemetry metric of keda-operator.")
pflag.StringVar(&metricsAddr, "metrics-bind-address", ":8080", "The address the prometheus metric endpoint binds to.")
Expand All @@ -110,6 +111,7 @@ func main() {
pflag.StringVar(&validatingWebhookName, "validating-webhook-name", "keda-admission", "ValidatingWebhookConfiguration name. Defaults to keda-admission")
pflag.StringArrayVar(&caDirs, "ca-dir", []string{"/custom/ca"}, "Directory with CA certificates for scalers to authenticate TLS connections. Can be specified multiple times. Defaults to /custom/ca")
pflag.BoolVar(&enableWebhookPatching, "enable-webhook-patching", true, "Enable patching of webhook resources. Defaults to true.")
pflag.StringVar(&connectionPoolConfigPath, "pool-config-path", "", "Path to the global KEDA pool configuration file (optional)")
opts := zap.Options{}
opts.BindFlags(flag.CommandLine)
pflag.CommandLine.AddGoFlagSet(flag.CommandLine)
Expand Down Expand Up @@ -157,7 +159,12 @@ func main() {
metricsAddr = "0"
}
metricscollector.NewMetricsCollectors(enablePrometheusMetrics, enableOpenTelemetryMetrics)

if connectionPoolConfigPath != "" {
setupLog.Info("Initializing global pool configuration", "path", connectionPoolConfigPath)
connectionpool.InitGlobalPoolConfig(ctx, connectionPoolConfigPath)
} else {
setupLog.Info("No pool configuration file found, continuing with defaults")
}
mgr, err := ctrl.NewManager(cfg, ctrl.Options{
Scheme: scheme,
Metrics: server.Options{
Expand Down
97 changes: 97 additions & 0 deletions pkg/scalers/connectionpool/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package connectionpool

import (
"context"
"fmt"
"os"
"sync"

"github.com/fsnotify/fsnotify"
"gopkg.in/yaml.v3"
"sigs.k8s.io/controller-runtime/pkg/log"
)

var (
globalOverrides sync.Map
configPath string
logger = log.Log.WithName("connectionpool")
)

// InitGlobalPoolConfig loads the YAML config at path and starts a background
// watcher goroutine that reloads it whenever the file changes; the watcher
// exits when ctx is cancelled.
//
// NOTE(review): configPath is a package-level variable written here and read
// by the watcher goroutine. This is safe only because the write happens
// before the `go` statement; the function is expected to be called once at
// process startup — confirm it is never called concurrently.
func InitGlobalPoolConfig(ctx context.Context, path string) {
	configPath = path          // remember the path for subsequent reloads
	loadConfig()               // synchronous initial load so lookups work immediately
	go startConfigWatcher(ctx) // live reloads happen in the background
}

// loadConfig reads and parses the YAML file at configPath and updates
// globalOverrides in place.
//
// The update is done store-first, prune-second so that concurrent
// LookupConfigValue readers never observe a window in which a key that is
// valid both before and after the reload is temporarily missing (the
// previous clear-then-store order had exactly that gap).
//
// Error handling: a read failure (e.g. the file was removed) clears all
// overrides so stale values do not linger; a parse failure keeps the last
// known-good configuration and only logs the error.
func loadConfig() {
	data, err := os.ReadFile(configPath)
	if err != nil {
		logger.V(1).Info("No pool config found;", "path", configPath, "err", err)
		clearGlobalPoolOverride()
		return
	}

	var parsed map[string]string
	if err := yaml.Unmarshal(data, &parsed); err != nil {
		// Keep the previous configuration on a malformed file.
		logger.Error(err, "Invalid pool config format", "path", configPath)
		return
	}

	// Store the new values first so existing keys are updated atomically
	// from a reader's point of view.
	for key, val := range parsed {
		globalOverrides.Store(key, val)
	}

	// Then prune keys that are no longer present in the file.
	globalOverrides.Range(func(key, _ any) bool {
		if _, ok := parsed[key.(string)]; !ok {
			globalOverrides.Delete(key)
		}
		return true
	})

	logger.Info("Loaded global pool configuration", "entries", len(parsed))
}
// clearGlobalPoolOverride removes every entry from the global override map.
func clearGlobalPoolOverride() {
	var stale []any
	globalOverrides.Range(func(key, _ any) bool {
		stale = append(stale, key)
		return true
	})
	for _, key := range stale {
		globalOverrides.Delete(key)
	}
}

// startConfigWatcher watches the config file and reloads it on change. It
// returns when ctx is cancelled or the underlying watcher shuts down.
//
// Kubernetes projects ConfigMaps into pods via symlinks and updates them with
// an atomic rename, which fsnotify reports as Remove/Rename events on the
// watched path — after which the watch is silently dropped. We therefore
// re-establish the watch (and reload) on those events; without this, reloads
// stop working after the first ConfigMap update.
func startConfigWatcher(ctx context.Context) {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		logger.Error(err, "Failed to start global connection pool configuration file watcher")
		return
	}
	defer watcher.Close()

	if err := watcher.Add(configPath); err != nil {
		// Previously this error was ignored, leaving a watcher that never
		// fires while the operator believes reloads are active.
		logger.Error(err, "Failed to watch global connection pool configuration", "path", configPath)
		return
	}
	logger.Info("Started watching global connection pool configuration", "path", configPath)

	for {
		select {
		case event, ok := <-watcher.Events:
			if !ok {
				return // watcher closed
			}
			if event.Op&(fsnotify.Write|fsnotify.Create) != 0 {
				logger.Info("Detected pool config change; reloading")
				loadConfig()
			}
			if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
				// The file was replaced (typical for ConfigMap updates):
				// re-add the watch on the new inode and pick up its content.
				_ = watcher.Remove(configPath) // best effort; old watch may already be gone
				if err := watcher.Add(configPath); err != nil {
					logger.Error(err, "Failed to re-watch pool configuration", "path", configPath)
					return
				}
				logger.Info("Re-established pool config watch after file replacement")
				loadConfig()
			}
		case err, ok := <-watcher.Errors:
			if !ok {
				return // watcher closed
			}
			if err != nil {
				logger.Error(err, "Watcher error")
			}
		case <-ctx.Done():
			logger.Info("Stopping pool config watcher")
			return
		}
	}
}

// LookupConfigValue returns the override configured for the given scaler type
// and resource identifier. Keys are structured as "<scalerType>.<identifier>",
// e.g. "postgres.dbserver.db". It returns the empty string when no override
// exists.
func LookupConfigValue(scalerType, identifier string) string {
	key := fmt.Sprintf("%s.%s", scalerType, identifier)
	v, ok := globalOverrides.Load(key)
	if !ok {
		return ""
	}
	return v.(string)
}
129 changes: 129 additions & 0 deletions pkg/scalers/connectionpool/connectionpool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package connectionpool

import (
"context"
"errors"
"testing"

"go.uber.org/atomic"
)

// mockPool is a stand-in ResourcePool that records whether close was called.
type mockPool struct {
	closed atomic.Int32 // 0 = open, 1 = closed
}

// close marks the pool as closed; satisfies ResourcePool.
func (m *mockPool) close() {
	m.closed.Store(1)
}

// isClosed reports whether p is a mockPool whose close has been invoked.
// Non-mock pools are reported as not closed.
func isClosed(p ResourcePool) bool {
	m, ok := p.(*mockPool)
	return ok && m.closed.Load() == 1
}

// mockCreateFn returns a factory that always yields a fresh, healthy mockPool.
func mockCreateFn(_ context.Context) func() (ResourcePool, error) {
	return func() (ResourcePool, error) {
		return &mockPool{}, nil
	}
}

// TestGetOrCreate_ReusesPool verifies that two GetOrCreate calls with the
// same key hand back the identical pool instance.
func TestGetOrCreate_ReusesPool(t *testing.T) {
	ctx := context.Background()
	const poolKey = "postgres.db1.analytics"

	p1, err := GetOrCreate(poolKey, mockCreateFn(ctx))
	if err != nil {
		t.Fatalf("failed to create pool: %v", err)
	}

	p2, err := GetOrCreate(poolKey, mockCreateFn(ctx))
	if err != nil {
		t.Fatalf("failed to reuse pool: %v", err)
	}

	if p1 != p2 {
		t.Fatalf("expected same pool instance, got different ones")
	}

	// Balance both acquisitions so the pool is torn down.
	Release(poolKey)
	Release(poolKey)
}

// TestGetOrCreate_DifferentKeys verifies that distinct keys yield distinct
// pool instances. Unlike the original, creation errors are checked instead
// of being discarded with `_` — a failed creation would otherwise surface
// only as a confusing nil-vs-nil comparison.
func TestGetOrCreate_DifferentKeys(t *testing.T) {
	ctx := context.Background()

	pool1, err := GetOrCreate("postgres.db1.analytics", mockCreateFn(ctx))
	if err != nil {
		t.Fatalf("failed to create first pool: %v", err)
	}
	pool2, err := GetOrCreate("postgres.db2.reporting", mockCreateFn(ctx))
	if err != nil {
		t.Fatalf("failed to create second pool: %v", err)
	}

	if pool1 == pool2 {
		t.Fatalf("expected different pools for different keys")
	}

	Release("postgres.db1.analytics")
	Release("postgres.db2.reporting")
}

// TestRelease_RefCount verifies reference counting: the pool stays open while
// references remain, and is closed and removed from the registry after the
// final Release. The original discarded GetOrCreate errors and never asserted
// that the pool was actually closed — only that it was removed from the map.
func TestRelease_RefCount(t *testing.T) {
	ctx := context.Background()
	key := "postgres.db1.analytics"

	p1, err := GetOrCreate(key, mockCreateFn(ctx))
	if err != nil {
		t.Fatalf("first GetOrCreate: %v", err)
	}
	p2, err := GetOrCreate(key, mockCreateFn(ctx))
	if err != nil {
		t.Fatalf("second GetOrCreate: %v", err)
	}

	if p1 != p2 {
		t.Fatalf("expected pool reuse")
	}

	Release(key) // drops refcount from 2 to 1
	if isClosed(p1) {
		t.Fatalf("pool should not be closed with remaining references")
	}

	Release(key) // final reference: closes and deletes
	if !isClosed(p1) {
		t.Fatalf("expected pool to be closed after final release")
	}
	if _, found := poolMap.Load(key); found {
		t.Fatalf("expected pool to be removed after final release")
	}
}

// TestConcurrentAccess hammers GetOrCreate/Release for one key from several
// goroutines; run with `go test -race` to surface data races.
func TestConcurrentAccess(t *testing.T) {
	const (
		workers    = 10
		iterations = 100
	)
	key := "postgres.db1.analytics"

	done := make(chan bool)
	for w := 0; w < workers; w++ {
		go func() {
			for i := 0; i < iterations; i++ {
				pool, err := GetOrCreate(key, mockCreateFn(context.Background()))
				if err != nil {
					t.Errorf("get or create: %v", err)
				}
				if pool == nil {
					t.Errorf("nil pool returned")
				}
				Release(key)
			}
			done <- true
		}()
	}

	// Wait for every worker to finish before the test returns.
	for w := 0; w < workers; w++ {
		<-done
	}
}

// TestInvalidConnectionHandledGracefully verifies that a failing createFn
// propagates its error and leaves no pool registered, and that releasing a
// key that was never successfully created is a no-op.
//
// The original test used t.Log when err was nil, so it PASSED even when
// GetOrCreate swallowed the creation error — it asserted nothing.
func TestInvalidConnectionHandledGracefully(t *testing.T) {
	key := "postgres.invalid"
	errPoolCreationFailed := errors.New("pool creation failed")
	createFn := func() (ResourcePool, error) {
		return nil, errPoolCreationFailed
	}

	_, err := GetOrCreate(key, createFn)
	if err == nil {
		t.Fatal("expected error during pool creation")
	}
	if !errors.Is(err, errPoolCreationFailed) {
		t.Fatalf("expected creation error to be propagated, got: %v", err)
	}

	// A failed creation must not leave an entry behind.
	if _, found := poolMap.Load(key); found {
		t.Fatal("pool should not be registered after failed creation")
	}

	// Releasing a never-created key must be a safe no-op.
	Release(key)
}
76 changes: 76 additions & 0 deletions pkg/scalers/connectionpool/manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package connectionpool

import (
"sync"

"go.uber.org/atomic"
)

// ResourcePool is the minimal contract a pooled resource must satisfy so the
// manager can dispose of it once the last reference is released.
type ResourcePool interface {
	close()
}

// poolEntry pairs a pool with its reference count. All transitions happen
// while poolMu is held; ref stays an atomic so count reads in log statements
// are race-free regardless of caller.
type poolEntry struct {
	pool ResourcePool
	ref  atomic.Int32
}

var (
	// poolMu serializes create/acquire/release/close transitions.
	//
	// The previous lock-free design (bare sync.Map + atomic refcount) had
	// two TOCTOU races: (1) GetOrCreate could Load an entry and Inc its
	// refcount after a concurrent Release had already decided to close it,
	// handing the caller a closed pool; (2) two concurrent final Releases
	// could both observe a non-positive count and double-close the pool.
	// A single mutex makes every transition atomic.
	poolMu sync.Mutex

	// poolMap holds poolKey -> *poolEntry for every live pool. It remains a
	// sync.Map (rather than a plain map) to keep the existing read-only
	// introspection in tests working unchanged.
	poolMap sync.Map
)

// GetOrCreate returns the pool registered under poolKey, creating it with
// createFn on first use. Each successful call takes one reference that must
// be balanced by a call to Release. On creation failure the error from
// createFn is returned and nothing is registered.
//
// Note: createFn runs while poolMu is held, so it must not call back into
// this package; the cost is serialized creation, the benefit is that no
// caller can ever observe a half-registered or already-closed pool.
func GetOrCreate(poolKey string, createFn func() (ResourcePool, error)) (ResourcePool, error) {
	poolMu.Lock()
	defer poolMu.Unlock()

	if val, ok := poolMap.Load(poolKey); ok {
		entry := val.(*poolEntry)
		entry.ref.Inc()
		logger.V(1).Info("Reusing existing pool", "poolKey", poolKey, "refCount", entry.ref.Load())
		return entry.pool, nil
	}

	logger.Info("Creating new pool", "poolKey", poolKey)
	newPool, err := createFn()
	if err != nil {
		logger.Error(err, "Failed to create new pool", "poolKey", poolKey)
		return nil, err
	}

	e := &poolEntry{pool: newPool}
	e.ref.Store(1)
	// Under poolMu no concurrent creator can exist, so a plain Store
	// suffices — the old LoadOrStore "duplicate creation" path is gone.
	poolMap.Store(poolKey, e)

	logger.Info("Pool created successfully", "poolKey", poolKey)
	return newPool, nil
}

// Release drops one reference to the pool registered under poolKey. When the
// last reference is released the pool is closed and removed. Releasing a key
// with no registered pool is a logged no-op.
func Release(poolKey string) {
	poolMu.Lock()
	defer poolMu.Unlock()

	val, ok := poolMap.Load(poolKey)
	if !ok {
		logger.V(1).Info("Attempted to release non-existent pool", "poolKey", poolKey)
		return
	}
	entry := val.(*poolEntry)
	if entry.ref.Dec() <= 0 {
		logger.Info("Closing pool, no active references", "poolKey", poolKey)
		entry.pool.close()
		poolMap.Delete(poolKey)
		return
	}
	logger.V(1).Info("Released pool reference", "poolKey", poolKey, "refCount", entry.ref.Load())
}

// CloseAll closes and removes every registered pool regardless of its
// reference count. Intended for operator shutdown.
func CloseAll() {
	poolMu.Lock()
	defer poolMu.Unlock()

	poolMap.Range(func(key, val any) bool {
		entry := val.(*poolEntry)
		logger.Info("Closing pool", "poolKey", key)
		entry.pool.close()
		poolMap.Delete(key)
		return true
	})
}
Loading
Loading