Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,30 @@ type StreamingCredentialsProvider interface {
// Subscribe subscribes to the credentials provider for updates.
// It returns the current credentials, a cancel function to unsubscribe from the provider,
// and an error if any.
//
// Implementations MUST be idempotent with respect to listener identity:
// subscribing the same listener value more than once must not produce
// duplicate notifications and must not create multiple independent
// subscriptions that each need to be cancelled separately. Every
// UnsubscribeFunc returned for a given listener must cancel that
// listener's subscription; calling any one of them must be sufficient to
// stop updates to that listener, and calling subsequent ones must be a
// safe no-op. Callers (including go-redis internals) may retain only
// the most recently returned UnsubscribeFunc and rely on it to fully
// unsubscribe the listener.
//
// TODO(ndyakov): Should we add context to the Subscribe method?
Subscribe(listener CredentialsListener) (Credentials, UnsubscribeFunc, error)
}

// UnsubscribeFunc is a function that is used to cancel the subscription to the credentials provider.
// It is used to unsubscribe from the provider when the credentials are no longer needed.
//
// Per the StreamingCredentialsProvider.Subscribe contract, if the same
// listener is subscribed multiple times, every UnsubscribeFunc returned for
// that listener must fully unsubscribe it on first invocation, and
// subsequent invocations (from any of the equivalent UnsubscribeFuncs) must
// be a safe no-op.
type UnsubscribeFunc func() error

// CredentialsListener is an interface that defines the methods for a credentials listener.
Expand Down
35 changes: 35 additions & 0 deletions internal/pool/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,41 @@ func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Durat
}
}

// SetOnClose installs fn as the callback invoked exactly once when this
// connection is closed (via Conn.Close).
//
// IMPORTANT: SetOnClose OVERWRITES any previously installed callback — it
// does not compose, chain, or deduplicate. A Conn has room for a single
// onClose hook by design, because its lifecycle is bounded (a Conn is
// created, optionally re-initialized on its own net.Conn, and then closed
// once) and the pool's OnRemove hooks handle any registry-level cleanup
// that must survive the net.Conn being swapped.
//
// This has a subtle implication for per-connection subscriptions such as
// the unsubscribe function returned by StreamingCredentialsProvider
// (e.g. EntraID token rotation): if SetOnClose is called twice on the
// same Conn with DIFFERENT unsubscribe closures — for example because
// initConn ran a second time and obtained a fresh Subscribe() —
// the previous unsubscribe is dropped and will NEVER run, leaking a
// subscription on the provider. Callers must therefore ensure either:
//
//   - the provider's Subscribe is idempotent for the same listener (the
//     streaming credentials Manager deduplicates listeners by connection
//     id, so re-Subscribe returns an equivalent unsubscribe), OR
//   - the previous callback has already been invoked before SetOnClose is
//     called again.
//
// Design note: unlike the client-level onCloseHooks registry (see
// redis.baseClient), there is intentionally NO named-hook dedup or
// multi-callback support on Conn. This is a deliberate trade-off to keep
// the Conn object slim — a pool can hold thousands of Conn values and
// each one is a hot allocation, so paying for a sync.Mutex plus a
// map[string]func() error per connection to support a feature that would
// only be used by at most one subsystem today (streaming credentials) is
// not worth the per-connection memory and allocation cost. For a single
// Conn there is at most one meaningful close callback at any point in
// time, and a richer registry here would not even solve the "stale
// closure" hazard described above.
func (cn *Conn) SetOnClose(fn func() error) {
	cn.onClose = fn // last writer wins: any previously installed callback is dropped, never invoked
}
Expand Down
286 changes: 284 additions & 2 deletions internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"
"time"

"github.com/redis/go-redis/v9/auth"
"github.com/redis/go-redis/v9/internal/pool"
"github.com/redis/go-redis/v9/internal/proto"

Expand Down Expand Up @@ -460,6 +461,10 @@ func (ct *testCounter) expect(values map[string]int) {
}
}

// testOnCloseHookID is the id used by the ring-shard cleanup tests when
// registering a close hook against the internal onCloseHooks registry.
// A fixed id means re-registration replaces (rather than chains) the hook,
// matching the registry's dedup-by-id contract.
const testOnCloseHookID = "test-close-counter"

func TestRingShardsCleanup(t *testing.T) {
const (
ringShard1Name = "ringShardOne"
Expand All @@ -479,7 +484,7 @@ func TestRingShardsCleanup(t *testing.T) {
},
NewClient: func(opt *Options) *Client {
c := NewClient(opt)
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
c.baseClient.onClose.register(testOnCloseHookID, func() error {
closeCounter.increment(opt.Addr)
return nil
})
Expand Down Expand Up @@ -528,7 +533,7 @@ func TestRingShardsCleanup(t *testing.T) {
}
createCounter.increment(opt.Addr)
c := NewClient(opt)
c.baseClient.onClose = c.baseClient.wrappedOnClose(func() error {
c.baseClient.onClose.register(testOnCloseHookID, func() error {
closeCounter.increment(opt.Addr)
return nil
})
Expand Down Expand Up @@ -685,3 +690,280 @@ var _ = Describe("isLoopback", func() {
Entry("partial docker internal", "docker.internal", false),
)
})


// TestOnCloseHooks_RunInRegistrationOrder verifies that hooks registered under
// distinct ids are all invoked on run() in the order they were registered.
func TestOnCloseHooks_RunInRegistrationOrder(t *testing.T) {
	h := &onCloseHooks{}
	var observed []string

	// Register three hooks under distinct ids; each records its own id.
	for _, id := range []string{"a", "b", "c"} {
		id := id
		h.register(id, func() error {
			observed = append(observed, id)
			return nil
		})
	}

	if err := h.run(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if want := []string{"a", "b", "c"}; !reflect.DeepEqual(observed, want) {
		t.Fatalf("run order = %v, want %v", observed, want)
	}
}

// TestOnCloseHooks_RegisterSameIDReplaces is the regression test for issue
// #3772. Registering the same id repeatedly must replace the existing
// callback rather than chain onto it, so the registry stays bounded even
// under storm-like re-registration (the exact scenario that previously leaked
// when initConn re-wrapped c.onClose on every connection init).
func TestOnCloseHooks_RegisterSameIDReplaces(t *testing.T) {
	h := &onCloseHooks{}
	const id = "same-id"
	const iterations = 10_000

	var lastSeen int32
	for n := int32(0); n < iterations; n++ {
		n := n
		h.register(id, func() error {
			atomic.StoreInt32(&lastSeen, n)
			return nil
		})
	}

	// The registry must stay bounded: one entry in the order slice and one
	// entry in the hooks map, no matter how many times the id was re-used.
	if got := len(h.order); got != 1 {
		t.Fatalf("order length after %d re-registrations = %d, want 1", iterations, got)
	}
	if got := len(h.hooks); got != 1 {
		t.Fatalf("hooks map size after %d re-registrations = %d, want 1", iterations, got)
	}

	if err := h.run(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Only the last-registered closure may survive, so it alone runs.
	if got := atomic.LoadInt32(&lastSeen); got != iterations-1 {
		t.Fatalf("last-registered callback not invoked: lastSeen = %d, want %d", got, iterations-1)
	}
}

// TestOnCloseHooks_DistinctIDsCoexist guarantees the dedup behavior does not
// discard hooks from other callers: registering new ids must never drop
// previously registered ids.
func TestOnCloseHooks_DistinctIDsCoexist(t *testing.T) {
	h := &onCloseHooks{}
	counts := map[string]*int32{
		"a": new(int32),
		"b": new(int32),
	}
	// bump builds a hook that increments the counter for the given id.
	bump := func(id string) func() error {
		c := counts[id]
		return func() error {
			atomic.AddInt32(c, 1)
			return nil
		}
	}

	h.register("a", bump("a"))
	h.register("b", bump("b"))
	// Re-registering "a" must not drop "b".
	h.register("a", bump("a"))

	if err := h.run(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if a, b := atomic.LoadInt32(counts["a"]), atomic.LoadInt32(counts["b"]); a != 1 || b != 1 {
		t.Fatalf("call counts a=%d b=%d, want a=1 b=1", a, b)
	}
}

// TestOnCloseHooks_Unregister verifies that unregister removes a hook and
// that running after unregister does not invoke it.
func TestOnCloseHooks_Unregister(t *testing.T) {
	h := &onCloseHooks{}
	invoked := map[string]bool{}

	h.register("a", func() error { invoked["a"] = true; return nil })
	h.register("b", func() error { invoked["b"] = true; return nil })
	h.unregister("a")
	h.unregister("missing") // no-op must not panic

	if err := h.run(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	switch {
	case invoked["a"]:
		t.Fatal("unregistered hook was invoked")
	case !invoked["b"]:
		t.Fatal("remaining hook was not invoked")
	}
	if got := len(h.order); got != 1 {
		t.Fatalf("order length = %d, want 1", got)
	}
}


// TestOnCloseHooks_AllRunOnError confirms every hook is invoked even if an
// earlier one returns an error, and that the first error is returned.
func TestOnCloseHooks_AllRunOnError(t *testing.T) {
	h := &onCloseHooks{}
	errFirst := fmt.Errorf("first")
	errSecond := fmt.Errorf("second")

	invoked := make(map[string]bool, 3)
	// mark builds a hook that records its invocation and returns ret.
	mark := func(id string, ret error) func() error {
		return func() error {
			invoked[id] = true
			return ret
		}
	}

	h.register("a", mark("a", errFirst))
	h.register("b", mark("b", errSecond))
	h.register("c", mark("c", nil))

	if err := h.run(); err != errFirst {
		t.Fatalf("run() err = %v, want %v", err, errFirst)
	}
	for i, id := range []string{"a", "b", "c"} {
		if !invoked[id] {
			t.Fatalf("hook %d was not invoked", i)
		}
	}
}

// TestOnCloseHooks_NilReceiver ensures run() on a nil registry is a safe
// no-op. baseClient embedded in Conn/Tx does initialize the registry, but
// defensive nil-safety lets future constructors add the field without
// breaking Close().
func TestOnCloseHooks_NilReceiver(t *testing.T) {
	var nilRegistry *onCloseHooks
	err := nilRegistry.run()
	if err != nil {
		t.Fatalf("run() on nil = %v, want nil", err)
	}
}

// TestOnCloseHooks_ConcurrentRegisterSameID hammers the registry with many
// goroutines re-registering under the same id. The registry must remain
// bounded and the surviving callback must still be invoked exactly once.
func TestOnCloseHooks_ConcurrentRegisterSameID(t *testing.T) {
	h := &onCloseHooks{}
	const id = "hot"
	const goroutines = 64
	const perG = 1_000

	// Every registered closure increments the same counter, so whichever
	// closure survives the storm, run() invoking it exactly once must
	// leave calls == 1.
	var calls int32

	var wg sync.WaitGroup
	wg.Add(goroutines)
	for g := 0; g < goroutines; g++ {
		go func() {
			defer wg.Done()
			for i := 0; i < perG; i++ {
				h.register(id, func() error {
					atomic.AddInt32(&calls, 1)
					return nil
				})
			}
		}()
	}
	wg.Wait()

	if got := len(h.order); got != 1 {
		t.Fatalf("order length after concurrent storm = %d, want 1", got)
	}

	// Previously this test only checked registry size; actually run the
	// registry and verify the surviving callback fires exactly once, as
	// the comment above promises.
	if err := h.run(); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got := atomic.LoadInt32(&calls); got != 1 {
		t.Fatalf("surviving callback invoked %d times, want exactly 1", got)
	}
}


// entraidLikeProvider mimics the exact semantics of
// github.com/redis/go-redis-entraid's StreamingCredentialsProvider relevant
// to issue #3772: it deduplicates subscriptions by listener pointer
// identity, and every call to Subscribe returns a FRESH UnsubscribeFunc
// closure that removes the listener by pointer match from the shared
// listeners slice. Two unsubs obtained for the same listener are therefore
// equivalent: the first one called removes the entry, any subsequent call
// is a safe no-op.
type entraidLikeProvider struct {
	mu sync.Mutex // guards listeners
	listeners []auth.CredentialsListener // currently subscribed listeners, deduped by identity
	subscribeN int32 // total Subscribe calls; accessed atomically
	unsubCalls int32 // total unsubscribe-closure invocations; accessed atomically
}

// Subscribe records the listener (once per identity) and returns fresh
// credentials together with a new unsubscribe closure that removes the
// listener; calling any equivalent closure after removal is a no-op.
func (p *entraidLikeProvider) Subscribe(listener auth.CredentialsListener) (auth.Credentials, auth.UnsubscribeFunc, error) {
	atomic.AddInt32(&p.subscribeN, 1)

	// Add the listener only if it is not already registered (dedup by
	// listener identity, mirroring the entraid streaming manager).
	func() {
		p.mu.Lock()
		defer p.mu.Unlock()
		for _, registered := range p.listeners {
			if registered == listener {
				return
			}
		}
		p.listeners = append(p.listeners, listener)
	}()

	unsub := func() error {
		atomic.AddInt32(&p.unsubCalls, 1)
		p.mu.Lock()
		defer p.mu.Unlock()
		idx := -1
		for i, registered := range p.listeners {
			if registered == listener {
				idx = i
				break
			}
		}
		if idx >= 0 {
			p.listeners = append(p.listeners[:idx], p.listeners[idx+1:]...)
		}
		return nil
	}
	return auth.NewBasicCredentials("u", "p"), unsub, nil
}

// listenerCount reports how many listeners are currently subscribed.
func (p *entraidLikeProvider) listenerCount() int {
	p.mu.Lock()
	n := len(p.listeners)
	p.mu.Unlock()
	return n
}

// stubCredentialsListener is a minimal auth.CredentialsListener used only
// for pointer-identity in the entraid-mimicking test.
type stubCredentialsListener struct{}

// OnNext ignores credential updates; the test cares only about identity.
func (*stubCredentialsListener) OnNext(auth.Credentials) {}

// OnError ignores errors; the test cares only about identity.
func (*stubCredentialsListener) OnError(error) {}

// TestInitConn_EntraidLike_NoLeakAcrossReinits is the targeted regression
// test for issue #3772 against the real-world StreamingCredentialsProvider
// behavior implemented by go-redis-entraid. It simulates N re-initializations
// on the same logical connection (i.e. the same CredentialsListener pointer,
// which is what streaming.Manager.Listener returns from its per-connId cache),
// replacing cn.SetOnClose with each new unsubscribe closure.
//
// Invariants the fix must uphold:
//  1. Subscribe dedups: the provider's listener list stays at size 1.
//  2. Calling only the MOST RECENT unsub fully removes the listener.
//  3. All prior (orphaned) unsubs are safe no-ops after that.
//  4. No registration remains on the provider after close.
func TestInitConn_EntraidLike_NoLeakAcrossReinits(t *testing.T) {
	const reinits = 1000

	provider := &entraidLikeProvider{}
	listener := &stubCredentialsListener{}

	var newest auth.UnsubscribeFunc
	stale := make([]auth.UnsubscribeFunc, 0, reinits-1)

	for i := 0; i < reinits; i++ {
		_, unsub, err := provider.Subscribe(listener)
		if err != nil {
			t.Fatalf("Subscribe #%d: %v", i, err)
		}
		// Mirror the pool.Conn.SetOnClose behavior: the previous unsub
		// is dropped on the floor, only the latest one is retained.
		if newest != nil {
			stale = append(stale, newest)
		}
		newest = unsub
	}

	if got := provider.listenerCount(); got != 1 {
		t.Fatalf("after %d Subscribes with same listener, listener count = %d, want 1", reinits, got)
	}
	if got := atomic.LoadInt32(&provider.subscribeN); got != int32(reinits) {
		t.Fatalf("Subscribe call count = %d, want %d", got, reinits)
	}

	// Only the latest unsub is invoked, matching the post-fix cn.onClose.
	if err := newest(); err != nil {
		t.Fatalf("latest unsub returned error: %v", err)
	}
	if got := provider.listenerCount(); got != 0 {
		t.Fatalf("listener count after latest unsub = %d, want 0", got)
	}

	// Every orphaned unsub must be a safe no-op (contract in auth.UnsubscribeFunc).
	for i, orphan := range stale {
		if err := orphan(); err != nil {
			t.Fatalf("orphaned unsub #%d returned error: %v", i, err)
		}
	}
	if got := provider.listenerCount(); got != 0 {
		t.Fatalf("listener count after orphaned unsubs = %d, want 0", got)
	}
}
Loading
Loading