30 changes: 24 additions & 6 deletions maintnotifications/e2e/notification_injector.go
@@ -523,19 +523,37 @@ func formatSMigratingNotification(seqID int64, slots ...string) string {
}

func formatSMigratedNotification(seqID int64, endpoints ...string) string {
// New Format: ["SMIGRATED", SeqID, count, [endpoint1, endpoint2, ...]]
// Correct Format: ["SMIGRATED", SeqID, [[host:port, slots], [host:port, slots], ...]]
// RESP3 wire format:
// >3
// +SMIGRATED
// :SeqID
// *<num_entries>
//   *2
//   +<host:port>
//   +<slots-or-ranges>
// Each endpoint is formatted as: "host:port slot1,slot2,range1-range2"
// Example: >4\r\n$9\r\nSMIGRATED\r\n:15\r\n:2\r\n*2\r\n$31\r\n127.0.0.1:6379 123,456,789-1000\r\n$30\r\n127.0.0.1:6380 124,457,300-500\r\n
parts := []string{">4\r\n"}
parts = append(parts, "$9\r\nSMIGRATED\r\n")
parts := []string{">3\r\n"}
parts = append(parts, "+SMIGRATED\r\n")
parts = append(parts, fmt.Sprintf(":%d\r\n", seqID))

count := len(endpoints)
parts = append(parts, fmt.Sprintf(":%d\r\n", count))
parts = append(parts, fmt.Sprintf("*%d\r\n", count))

for _, endpoint := range endpoints {
parts = append(parts, fmt.Sprintf("$%d\r\n%s\r\n", len(endpoint), endpoint))
// Split endpoint into host:port and slots
// endpoint format: "host:port slot1,slot2,range1-range2"
endpointParts := strings.SplitN(endpoint, " ", 2)
if len(endpointParts) != 2 {
continue
}
hostPort := endpointParts[0]
slots := endpointParts[1]

// Each entry is an array with 2 elements
parts = append(parts, "*2\r\n")
parts = append(parts, fmt.Sprintf("+%s\r\n", hostPort))
parts = append(parts, fmt.Sprintf("+%s\r\n", slots))
}

return strings.Join(parts, "")
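For reference, the new wire layout can be reproduced outside the test harness with a short, self-contained sketch. buildSMigrated below is a hypothetical stand-in for formatSMigratedNotification that takes pre-split pairs instead of space-separated endpoint strings:

package main

import (
	"fmt"
	"strings"
)

// buildSMigrated emits the RESP3 push message described above:
// >3, +SMIGRATED, :seqID, then an array of [host:port, slots] pairs.
func buildSMigrated(seqID int64, entries [][2]string) string {
	var b strings.Builder
	b.WriteString(">3\r\n")
	b.WriteString("+SMIGRATED\r\n")
	fmt.Fprintf(&b, ":%d\r\n", seqID)
	fmt.Fprintf(&b, "*%d\r\n", len(entries))
	for _, e := range entries {
		b.WriteString("*2\r\n")
		fmt.Fprintf(&b, "+%s\r\n", e[0]) // host:port
		fmt.Fprintf(&b, "+%s\r\n", e[1]) // slots or ranges
	}
	return b.String()
}

func main() {
	msg := buildSMigrated(15, [][2]string{{"127.0.0.1:6379", "123,789-1000"}})
	fmt.Printf("%q\n", msg)
	// ">3\r\n+SMIGRATED\r\n:15\r\n*1\r\n*2\r\n+127.0.0.1:6379\r\n+123,789-1000\r\n"
}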
72 changes: 37 additions & 35 deletions maintnotifications/push_notification_handler.go
@@ -341,12 +341,19 @@ func (snh *NotificationHandler) handleSMigrating(ctx context.Context, handlerCtx
// handleSMigrated processes SMIGRATED notifications.
// SMIGRATED indicates that a cluster slot has finished migrating to a different node.
// This is a cluster-level notification that triggers cluster state reload.
// Expected format: ["SMIGRATED", SeqID, count, [endpoint1, endpoint2, ...]]
// Each endpoint is formatted as: "host:port slot1,slot2,range1-range2"
// Expected format: ["SMIGRATED", SeqID, [[host:port, slots], [host:port, slots], ...]]
// RESP3 wire format:
// >3
// +SMIGRATED
// :SeqID
// *<num_entries>
//   *2
//   +<host:port>
//   +<slots-or-ranges>
// Note: Multiple connections may receive the same notification, so we deduplicate by SeqID before triggering a reload,
// but we still process the notification on each connection to clear the relaxed timeout.
func (snh *NotificationHandler) handleSMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
if len(notification) != 4 {
if len(notification) != 3 {
internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED", notification))
return ErrInvalidNotification
}
@@ -358,51 +365,46 @@ func (snh *NotificationHandler) handleSMigrated(ctx context.Context, handlerCtx
return ErrInvalidNotification
}

// Extract count (position 2)
count, ok := notification[2].(int64)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (count)", notification))
return ErrInvalidNotification
}

// Extract endpoints array (position 3)
endpointsArray, ok := notification[3].([]interface{})
// Extract endpoints array (position 2)
// Each entry is an array: [host:port, slots]
endpointsArray, ok := notification[2].([]interface{})
if !ok {
internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (endpoints)", notification))
return ErrInvalidNotification
}

if int64(len(endpointsArray)) != count {
internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (count mismatch)", notification))
return ErrInvalidNotification
}

// Parse endpoints
endpoints := make([]string, 0, count)
for _, ep := range endpointsArray {
if endpoint, ok := ep.(string); ok {
endpoints = append(endpoints, endpoint)
}
}

// Deduplicate by SeqID - multiple connections may receive the same notification
if snh.manager.MarkSMigratedSeqIDProcessed(seqID) {
// For logging and triggering reload, we use the first endpoint's host:port
// and collect all slot ranges from all endpoints
var hostPort string
var allSlotRanges []string

for _, endpoint := range endpoints {
// Parse endpoint: "host:port slot1,slot2,range1-range2"
parts := strings.SplitN(endpoint, " ", 2)
if len(parts) == 2 {
if hostPort == "" {
hostPort = parts[0]
}
// Split slots by comma
slots := strings.Split(parts[1], ",")
allSlotRanges = append(allSlotRanges, slots...)
for _, ep := range endpointsArray {
// Each endpoint is an array: [host:port, slots]
endpointParts, ok := ep.([]interface{})
if !ok || len(endpointParts) != 2 {
internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (endpoint format)", ep))
continue
}

// Extract host:port (element 0)
hostPortStr, ok := endpointParts[0].(string)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (host:port)", endpointParts[0]))
continue
}

// Extract slots (element 1)
slotsStr, ok := endpointParts[1].(string)
if !ok {
internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (slots)", endpointParts[1]))
continue
}

// Keep the first endpoint's host:port for logging, per the comment above
if hostPort == "" {
hostPort = hostPortStr
}
slotRanges := strings.Split(slotsStr, ",")
allSlotRanges = append(allSlotRanges, slotRanges...)
}

if internal.LogLevel.InfoOrAbove() {
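The parsing loop above can be exercised in isolation. This sketch (parseSMigratedEntries is a hypothetical helper, not part of the handler) applies the same type assertions and skip-on-malformed policy to a plain []interface{} payload:

package main

import (
	"fmt"
	"strings"
)

// parseSMigratedEntries walks the entries array (position 2 of the
// notification), keeping the first endpoint's host:port and collecting
// slot ranges from every entry; malformed entries are skipped, as in
// the handler (which also logs them).
func parseSMigratedEntries(entries []interface{}) (hostPort string, slotRanges []string) {
	for _, ep := range entries {
		pair, ok := ep.([]interface{})
		if !ok || len(pair) != 2 {
			continue
		}
		hp, ok1 := pair[0].(string)
		slots, ok2 := pair[1].(string)
		if !ok1 || !ok2 {
			continue
		}
		if hostPort == "" {
			hostPort = hp
		}
		slotRanges = append(slotRanges, strings.Split(slots, ",")...)
	}
	return hostPort, slotRanges
}

func main() {
	entries := []interface{}{
		[]interface{}{"127.0.0.1:6379", "123,456,789-1000"},
		[]interface{}{"127.0.0.1:6380", "124,457,300-500"},
	}
	hp, ranges := parseSMigratedEntries(entries)
	fmt.Println(hp, ranges)
	// 127.0.0.1:6379 [123 456 789-1000 124 457 300-500]
}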
74 changes: 74 additions & 0 deletions maintnotifications/push_notification_handler_test.go
@@ -0,0 +1,74 @@
package maintnotifications

import (
"context"
"testing"

"github.com/redis/go-redis/v9/push"
)

// TestHandleSMigrated_CorrectFormat verifies that handleSMigrated parses the documented notification format
func TestHandleSMigrated_CorrectFormat(t *testing.T) {
// Create a minimal manager for testing
config := DefaultConfig()
manager := &Manager{
config: config,
}

handler := &NotificationHandler{
manager: manager,
}

// Create notification in the correct format:
// ["SMIGRATED", SeqID, [[host:port, slots], [host:port, slots], ...]]
notification := []interface{}{
"SMIGRATED",
int64(12346),
[]interface{}{
[]interface{}{"127.0.0.1:6379", "123,456,789-1000"},
[]interface{}{"127.0.0.1:6380", "124,457,300-500"},
},
}

ctx := context.Background()
handlerCtx := push.NotificationHandlerContext{
Conn: nil, // No connection needed for this test
}

// This should not return an error
err := handler.handleSMigrated(ctx, handlerCtx, notification)
if err != nil {
t.Errorf("handleSMigrated failed with correct format: %v", err)
}
}

// TestHandleSMigrated_SingleEndpoint tests parsing with a single endpoint
func TestHandleSMigrated_SingleEndpoint(t *testing.T) {
config := DefaultConfig()
manager := &Manager{
config: config,
}

handler := &NotificationHandler{
manager: manager,
}

notification := []interface{}{
"SMIGRATED",
int64(100),
[]interface{}{
[]interface{}{"127.0.0.1:6380", "1000,2000-3000"},
},
}

ctx := context.Background()
handlerCtx := push.NotificationHandlerContext{
Conn: nil,
}

err := handler.handleSMigrated(ctx, handlerCtx, notification)
if err != nil {
t.Errorf("handleSMigrated failed with single endpoint: %v", err)
}
}

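A negative case in the same style would pin down the new arity check. This is a suggested sketch (not part of the PR), assuming the handler rejects the legacy 4-element layout as the length check above implies:

// TestHandleSMigrated_WrongArity checks that the legacy 4-element
// layout ["SMIGRATED", SeqID, count, [...]] is rejected.
func TestHandleSMigrated_WrongArity(t *testing.T) {
	handler := &NotificationHandler{
		manager: &Manager{config: DefaultConfig()},
	}

	notification := []interface{}{
		"SMIGRATED",
		int64(7),
		int64(1), // legacy count field
		[]interface{}{[]interface{}{"127.0.0.1:6379", "100"}},
	}

	err := handler.handleSMigrated(context.Background(), push.NotificationHandlerContext{}, notification)
	if err != ErrInvalidNotification {
		t.Errorf("expected ErrInvalidNotification for 4-element payload, got %v", err)
	}
}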
45 changes: 27 additions & 18 deletions osscluster.go
@@ -179,6 +179,9 @@ type ClusterOptions struct {
// ShardPicker is used to pick a shard when the request_policy is
// ReqDefault and the command has no keys.
ShardPicker routing.ShardPicker

// ClusterStateReloadInterval controls how old the cached cluster state may
// grow before a lazy background reload is triggered. Zero selects the
// 30-second default.
ClusterStateReloadInterval time.Duration
}

func (opt *ClusterOptions) init() {
Expand Down Expand Up @@ -258,6 +261,10 @@ func (opt *ClusterOptions) init() {
if opt.ShardPicker == nil {
opt.ShardPicker = &routing.RoundRobinPicker{}
}

if opt.ClusterStateReloadInterval == 0 {
opt.ClusterStateReloadInterval = 30 * time.Second
}
}

// ParseClusterURL parses a URL into ClusterOptions that can be used to connect to Redis.
Expand Down Expand Up @@ -422,17 +429,17 @@ func (opt *ClusterOptions) clientOptions() *Options {

ContextTimeoutEnabled: opt.ContextTimeoutEnabled,

PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
MaxConcurrentDials: opt.MaxConcurrentDials,
PoolTimeout: opt.PoolTimeout,
MinIdleConns: opt.MinIdleConns,
MaxIdleConns: opt.MaxIdleConns,
MaxActiveConns: opt.MaxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
PoolFIFO: opt.PoolFIFO,
PoolSize: opt.PoolSize,
MaxConcurrentDials: opt.MaxConcurrentDials,
PoolTimeout: opt.PoolTimeout,
MinIdleConns: opt.MinIdleConns,
MaxIdleConns: opt.MaxIdleConns,
MaxActiveConns: opt.MaxActiveConns,
ConnMaxIdleTime: opt.ConnMaxIdleTime,
ConnMaxLifetime: opt.ConnMaxLifetime,
ReadBufferSize: opt.ReadBufferSize,
WriteBufferSize: opt.WriteBufferSize,
DisableIdentity: opt.DisableIdentity,
DisableIndentity: opt.DisableIdentity,
IdentitySuffix: opt.IdentitySuffix,
Expand Down Expand Up @@ -1025,14 +1032,16 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode {
type clusterStateHolder struct {
load func(ctx context.Context) (*clusterState, error)

state atomic.Value
reloading uint32 // atomic
reloadPending uint32 // atomic - set to 1 when reload is requested during active reload
reloadInterval time.Duration
state atomic.Value
reloading uint32 // atomic
reloadPending uint32 // atomic - set to 1 when reload is requested during active reload
}

func newClusterStateHolder(load func(ctx context.Context) (*clusterState, error)) *clusterStateHolder {
func newClusterStateHolder(load func(ctx context.Context) (*clusterState, error), reloadInterval time.Duration) *clusterStateHolder {
return &clusterStateHolder{
load: load,
load: load,
reloadInterval: reloadInterval,
}
}

Expand Down Expand Up @@ -1087,7 +1096,7 @@ func (c *clusterStateHolder) Get(ctx context.Context) (*clusterState, error) {
}

state := v.(*clusterState)
if time.Since(state.createdAt) > 10*time.Second {
if time.Since(state.createdAt) > c.reloadInterval {
c.LazyReload()
}
return state, nil
Expand Down Expand Up @@ -1128,7 +1137,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient {

c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo)

c.state = newClusterStateHolder(c.loadState)
c.state = newClusterStateHolder(c.loadState, opt.ClusterStateReloadInterval)

c.SetCommandInfoResolver(NewDefaultCommandPolicyResolver())

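The new option is set like any other ClusterOptions field. A minimal usage sketch (addresses are illustrative):

package main

import (
	"time"

	"github.com/redis/go-redis/v9"
)

func main() {
	// Let cached cluster topology grow up to 15s old before a lazy
	// reload is triggered; leaving the field zero falls back to the
	// 30s default applied in ClusterOptions.init().
	rdb := redis.NewClusterClient(&redis.ClusterOptions{
		Addrs:                      []string{"127.0.0.1:7000", "127.0.0.1:7001"},
		ClusterStateReloadInterval: 15 * time.Second,
	})
	defer rdb.Close()
}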
14 changes: 7 additions & 7 deletions osscluster_lazy_reload_test.go
@@ -16,7 +16,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) {
reloadCount.Add(1)
time.Sleep(50 * time.Millisecond) // Simulate reload work
return &clusterState{}, nil
})
}, 10*time.Second)

// Trigger one reload
holder.LazyReload()
@@ -36,7 +36,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) {
reloadCount.Add(1)
time.Sleep(50 * time.Millisecond) // Simulate reload work
return &clusterState{}, nil
})
}, 10*time.Second)

// Trigger multiple reloads concurrently
for i := 0; i < 10; i++ {
@@ -59,7 +59,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) {
reloadCount.Add(1)
time.Sleep(10 * time.Millisecond) // Simulate reload work
return &clusterState{}, nil
})
}, 10*time.Second)

// Trigger first reload
holder.LazyReload()
@@ -86,7 +86,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) {
reloadCount.Add(1)
time.Sleep(10 * time.Millisecond) // Simulate reload work
return &clusterState{}, nil
})
}, 10*time.Second)

// Trigger first reload
holder.LazyReload()
@@ -118,7 +118,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) {
reloadCount.Add(1)
time.Sleep(10 * time.Millisecond) // Simulate reload work
return &clusterState{}, nil
})
}, 10*time.Second)

// Trigger first reload
holder.LazyReload()
@@ -149,7 +149,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) {
return nil, context.DeadlineExceeded
}
return &clusterState{}, nil
})
}, 10*time.Second)

// Trigger reload that will fail
holder.LazyReload()
@@ -179,7 +179,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) {
reloadCount.Add(1)
time.Sleep(20 * time.Millisecond) // Simulate realistic reload time
return &clusterState{}, nil
})
}, 10*time.Second)

// Simulate 5 SMIGRATED notifications arriving within 100ms
for i := 0; i < 5; i++ {
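Independent of go-redis internals, the behavior these tests pin down, one active reload plus at most one queued follow-up, can be sketched with two atomic flags. A self-contained sketch (coalescer and its fields are hypothetical names mirroring the reloading/reloadPending pair above, not the library's types):

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// coalescer collapses bursts of reload requests: one reload runs at a
// time, and any requests arriving mid-reload fold into a single
// follow-up run. Simplified for illustration.
type coalescer struct {
	reloading     uint32
	reloadPending uint32
	work          func()
}

func (c *coalescer) LazyReload() {
	if !atomic.CompareAndSwapUint32(&c.reloading, 0, 1) {
		// A reload is already in flight; queue at most one more.
		atomic.StoreUint32(&c.reloadPending, 1)
		return
	}
	go func() {
		for {
			c.work()
			// Run once more if anything was requested while we worked.
			if !atomic.CompareAndSwapUint32(&c.reloadPending, 1, 0) {
				break
			}
		}
		atomic.StoreUint32(&c.reloading, 0)
	}()
}

func main() {
	var runs atomic.Int32
	c := &coalescer{work: func() {
		runs.Add(1)
		time.Sleep(20 * time.Millisecond)
	}}
	for i := 0; i < 5; i++ {
		c.LazyReload() // burst of 5 requests, as in the notification-storm test
	}
	time.Sleep(100 * time.Millisecond)
	fmt.Println("reloads:", runs.Load()) // expect 2: one active + one coalesced
}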