diff --git a/maintnotifications/e2e/notification_injector.go b/maintnotifications/e2e/notification_injector.go index b265d1bebe..5a13664633 100644 --- a/maintnotifications/e2e/notification_injector.go +++ b/maintnotifications/e2e/notification_injector.go @@ -523,19 +523,37 @@ func formatSMigratingNotification(seqID int64, slots ...string) string { } func formatSMigratedNotification(seqID int64, endpoints ...string) string { - // New Format: ["SMIGRATED", SeqID, count, [endpoint1, endpoint2, ...]] + // Correct Format: ["SMIGRATED", SeqID, [[host:port, slots], [host:port, slots], ...]] + // RESP3 wire format: + // >3 + // +SMIGRATED + // :SeqID + // * + // *2 + // + + // + // Each endpoint is formatted as: "host:port slot1,slot2,range1-range2" - // Example: >4\r\n$9\r\nSMIGRATED\r\n:15\r\n:2\r\n*2\r\n$31\r\n127.0.0.1:6379 123,456,789-1000\r\n$30\r\n127.0.0.1:6380 124,457,300-500\r\n - parts := []string{">4\r\n"} - parts = append(parts, "$9\r\nSMIGRATED\r\n") + parts := []string{">3\r\n"} + parts = append(parts, "+SMIGRATED\r\n") parts = append(parts, fmt.Sprintf(":%d\r\n", seqID)) count := len(endpoints) - parts = append(parts, fmt.Sprintf(":%d\r\n", count)) parts = append(parts, fmt.Sprintf("*%d\r\n", count)) for _, endpoint := range endpoints { - parts = append(parts, fmt.Sprintf("$%d\r\n%s\r\n", len(endpoint), endpoint)) + // Split endpoint into host:port and slots + // endpoint format: "host:port slot1,slot2,range1-range2" + endpointParts := strings.SplitN(endpoint, " ", 2) + if len(endpointParts) != 2 { + continue + } + hostPort := endpointParts[0] + slots := endpointParts[1] + + // Each entry is an array with 2 elements + parts = append(parts, "*2\r\n") + parts = append(parts, fmt.Sprintf("+%s\r\n", hostPort)) + parts = append(parts, fmt.Sprintf("+%s\r\n", slots)) } return strings.Join(parts, "") diff --git a/maintnotifications/push_notification_handler.go b/maintnotifications/push_notification_handler.go index d5d0e48b52..93fee8fb28 100644 --- a/maintnotifications/push_notification_handler.go +++ b/maintnotifications/push_notification_handler.go @@ -341,12 +341,19 @@ func (snh *NotificationHandler) handleSMigrating(ctx context.Context, handlerCtx // handleSMigrated processes SMIGRATED notifications. // SMIGRATED indicates that a cluster slot has finished migrating to a different node. // This is a cluster-level notification that triggers cluster state reload. -// Expected format: ["SMIGRATED", SeqID, count, [endpoint1, endpoint2, ...]] -// Each endpoint is formatted as: "host:port slot1,slot2,range1-range2" +// Expected format: ["SMIGRATED", SeqID, [[host:port, slots], [host:port, slots], ...]] +// RESP3 wire format: +// >3 +// +SMIGRATED +// :SeqID +// * +// *2 +// + +// + // Note: Multiple connections may receive the same notification, so we deduplicate by SeqID before triggering reload. // but we still process the notification on each connection to clear the relaxed timeout. func (snh *NotificationHandler) handleSMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { - if len(notification) != 4 { + if len(notification) != 3 { internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED", notification)) return ErrInvalidNotification } @@ -358,33 +365,14 @@ func (snh *NotificationHandler) handleSMigrated(ctx context.Context, handlerCtx return ErrInvalidNotification } - // Extract count (position 2) - count, ok := notification[2].(int64) - if !ok { - internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (count)", notification)) - return ErrInvalidNotification - } - - // Extract endpoints array (position 3) - endpointsArray, ok := notification[3].([]interface{}) + // Extract endpoints array (position 2) + // Each entry is an array: [host:port, slots] + endpointsArray, ok := notification[2].([]interface{}) if !ok { internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (endpoints)", notification)) return ErrInvalidNotification } - if int64(len(endpointsArray)) != count { - internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (count mismatch)", notification)) - return ErrInvalidNotification - } - - // Parse endpoints - endpoints := make([]string, 0, count) - for _, ep := range endpointsArray { - if endpoint, ok := ep.(string); ok { - endpoints = append(endpoints, endpoint) - } - } - // Deduplicate by SeqID - multiple connections may receive the same notification if snh.manager.MarkSMigratedSeqIDProcessed(seqID) { // For logging and triggering reload, we use the first endpoint's host:port @@ -392,17 +380,31 @@ func (snh *NotificationHandler) handleSMigrated(ctx context.Context, handlerCtx var hostPort string var allSlotRanges []string - for _, endpoint := range endpoints { - // Parse endpoint: "host:port slot1,slot2,range1-range2" - parts := strings.SplitN(endpoint, " ", 2) - if len(parts) == 2 { - if hostPort == "" { - hostPort = parts[0] - } - // Split slots by comma - slots := strings.Split(parts[1], ",") - allSlotRanges = append(allSlotRanges, slots...) + for _, ep := range endpointsArray { + // Each endpoint is an array: [host:port, slots] + endpointParts, ok := ep.([]interface{}) + if !ok || len(endpointParts) != 2 { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (endpoint format)", ep)) + continue + } + + // Extract host:port (element 0) + hostPortStr, ok := endpointParts[0].(string) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (host:port)", endpointParts[0])) + continue } + + // Extract slots (element 1) + slotsStr, ok := endpointParts[1].(string) + if !ok { + internal.Logger.Printf(ctx, logs.InvalidNotification("SMIGRATED (slots)", endpointParts[1])) + continue + } + + hostPort = hostPortStr + slotRanges := strings.Split(slotsStr, ",") + allSlotRanges = append(allSlotRanges, slotRanges...) } if internal.LogLevel.InfoOrAbove() { diff --git a/maintnotifications/push_notification_handler_test.go b/maintnotifications/push_notification_handler_test.go new file mode 100644 index 0000000000..c8439cc701 --- /dev/null +++ b/maintnotifications/push_notification_handler_test.go @@ -0,0 +1,74 @@ +package maintnotifications + +import ( + "context" + "testing" + + "github.com/redis/go-redis/v9/push" +) + +// TestHandleSMigrated_CorrectFormat tests that handleSMigrated correctly parses the correct format +func TestHandleSMigrated_CorrectFormat(t *testing.T) { + // Create a minimal manager for testing + config := DefaultConfig() + manager := &Manager{ + config: config, + } + + handler := &NotificationHandler{ + manager: manager, + } + + // Create notification in the correct format: + // ["SMIGRATED", SeqID, [[host:port, slots], [host:port, slots], ...]] + notification := []interface{}{ + "SMIGRATED", + int64(12346), + []interface{}{ + []interface{}{"127.0.0.1:6379", "123,456,789-1000"}, + []interface{}{"127.0.0.1:6380", "124,457,300-500"}, + }, + } + + ctx := context.Background() + handlerCtx := push.NotificationHandlerContext{ + Conn: nil, // No connection needed for this test + } + + // This should not return an error + err := handler.handleSMigrated(ctx, handlerCtx, notification) + if err != nil { + t.Errorf("handleSMigrated failed with correct format: %v", err) + } +} + +// TestHandleSMigrated_SingleEndpoint tests parsing with a single endpoint +func TestHandleSMigrated_SingleEndpoint(t *testing.T) { + config := DefaultConfig() + manager := &Manager{ + config: config, + } + + handler := &NotificationHandler{ + manager: manager, + } + + notification := []interface{}{ + "SMIGRATED", + int64(100), + []interface{}{ + []interface{}{"127.0.0.1:6380", "1000,2000-3000"}, + }, + } + + ctx := context.Background() + handlerCtx := push.NotificationHandlerContext{ + Conn: nil, + } + + err := handler.handleSMigrated(ctx, handlerCtx, notification) + if err != nil { + t.Errorf("handleSMigrated failed with single endpoint: %v", err) + } +} + diff --git a/osscluster.go b/osscluster.go index 19b915c648..1c5b47bd6d 100644 --- a/osscluster.go +++ b/osscluster.go @@ -179,6 +179,10 @@ type ClusterOptions struct { // ShardPicker is used to pick a shard when the request_policy is // ReqDefault and the command has no keys. ShardPicker routing.ShardPicker + + // ClusterStateReloadInterval is the interval for reloading the cluster state. + // Default is 10 seconds. + ClusterStateReloadInterval time.Duration } func (opt *ClusterOptions) init() { @@ -258,6 +262,10 @@ func (opt *ClusterOptions) init() { if opt.ShardPicker == nil { opt.ShardPicker = &routing.RoundRobinPicker{} } + + if opt.ClusterStateReloadInterval == 0 { + opt.ClusterStateReloadInterval = 10 * time.Second + } } // ParseClusterURL parses a URL into ClusterOptions that can be used to connect to Redis. @@ -422,17 +430,17 @@ func (opt *ClusterOptions) clientOptions() *Options { ContextTimeoutEnabled: opt.ContextTimeoutEnabled, - PoolFIFO: opt.PoolFIFO, - PoolSize: opt.PoolSize, - MaxConcurrentDials: opt.MaxConcurrentDials, - PoolTimeout: opt.PoolTimeout, - MinIdleConns: opt.MinIdleConns, - MaxIdleConns: opt.MaxIdleConns, - MaxActiveConns: opt.MaxActiveConns, - ConnMaxIdleTime: opt.ConnMaxIdleTime, - ConnMaxLifetime: opt.ConnMaxLifetime, - ReadBufferSize: opt.ReadBufferSize, - WriteBufferSize: opt.WriteBufferSize, + PoolFIFO: opt.PoolFIFO, + PoolSize: opt.PoolSize, + MaxConcurrentDials: opt.MaxConcurrentDials, + PoolTimeout: opt.PoolTimeout, + MinIdleConns: opt.MinIdleConns, + MaxIdleConns: opt.MaxIdleConns, + MaxActiveConns: opt.MaxActiveConns, + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ReadBufferSize: opt.ReadBufferSize, + WriteBufferSize: opt.WriteBufferSize, DisableIdentity: opt.DisableIdentity, DisableIndentity: opt.DisableIdentity, IdentitySuffix: opt.IdentitySuffix, @@ -1025,14 +1033,16 @@ func (c *clusterState) slotNodes(slot int) []*clusterNode { type clusterStateHolder struct { load func(ctx context.Context) (*clusterState, error) - state atomic.Value - reloading uint32 // atomic - reloadPending uint32 // atomic - set to 1 when reload is requested during active reload + reloadInterval time.Duration + state atomic.Value + reloading uint32 // atomic + reloadPending uint32 // atomic - set to 1 when reload is requested during active reload } -func newClusterStateHolder(load func(ctx context.Context) (*clusterState, error)) *clusterStateHolder { +func newClusterStateHolder(load func(ctx context.Context) (*clusterState, error), reloadInterval time.Duration) *clusterStateHolder { return &clusterStateHolder{ - load: load, + load: load, + reloadInterval: reloadInterval, } } @@ -1087,7 +1097,7 @@ func (c *clusterStateHolder) Get(ctx context.Context) (*clusterState, error) { } state := v.(*clusterState) - if time.Since(state.createdAt) > 10*time.Second { + if time.Since(state.createdAt) > c.reloadInterval { c.LazyReload() } return state, nil @@ -1128,7 +1138,7 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { c.cmdsInfoCache = newCmdsInfoCache(c.cmdsInfo) - c.state = newClusterStateHolder(c.loadState) + c.state = newClusterStateHolder(c.loadState, opt.ClusterStateReloadInterval) c.SetCommandInfoResolver(NewDefaultCommandPolicyResolver()) diff --git a/osscluster_lazy_reload_test.go b/osscluster_lazy_reload_test.go index 994fd40e74..a09f2a9119 100644 --- a/osscluster_lazy_reload_test.go +++ b/osscluster_lazy_reload_test.go @@ -16,7 +16,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) { reloadCount.Add(1) time.Sleep(50 * time.Millisecond) // Simulate reload work return &clusterState{}, nil - }) + }, 10*time.Second) // Trigger one reload holder.LazyReload() @@ -36,7 +36,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) { reloadCount.Add(1) time.Sleep(50 * time.Millisecond) // Simulate reload work return &clusterState{}, nil - }) + }, 10*time.Second) // Trigger multiple reloads concurrently for i := 0; i < 10; i++ { @@ -59,7 +59,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) { reloadCount.Add(1) time.Sleep(10 * time.Millisecond) // Simulate reload work return &clusterState{}, nil - }) + }, 10*time.Second) // Trigger first reload holder.LazyReload() @@ -86,7 +86,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) { reloadCount.Add(1) time.Sleep(10 * time.Millisecond) // Simulate reload work return &clusterState{}, nil - }) + }, 10*time.Second) // Trigger first reload holder.LazyReload() @@ -118,7 +118,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) { reloadCount.Add(1) time.Sleep(10 * time.Millisecond) // Simulate reload work return &clusterState{}, nil - }) + }, 10*time.Second) // Trigger first reload holder.LazyReload() @@ -149,7 +149,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) { return nil, context.DeadlineExceeded } return &clusterState{}, nil - }) + }, 10*time.Second) // Trigger reload that will fail holder.LazyReload() @@ -179,7 +179,7 @@ func TestLazyReloadQueueBehavior(t *testing.T) { reloadCount.Add(1) time.Sleep(20 * time.Millisecond) // Simulate realistic reload time return &clusterState{}, nil - }) + }, 10*time.Second) // Simulate 5 SMIGRATED notifications arriving within 100ms for i := 0; i < 5; i++ { diff --git a/osscluster_maintnotifications_unit_test.go b/osscluster_maintnotifications_unit_test.go index 6e7331795b..7a18fa898e 100644 --- a/osscluster_maintnotifications_unit_test.go +++ b/osscluster_maintnotifications_unit_test.go @@ -100,22 +100,20 @@ func TestClusterMaintNotifications_SMigratingHandler(t *testing.T) { // TestClusterMaintNotifications_SMigratedHandler tests SMIGRATED notification handling func TestClusterMaintNotifications_SMigratedHandler(t *testing.T) { - // Simulate receiving a SMIGRATED notification with new format - // Format: ["SMIGRATED", SeqID, count, [endpoint1, endpoint2, ...]] - // Each endpoint is "host:port slot1,slot2,range1-range2" + // Simulate receiving a SMIGRATED notification with correct format + // Format: ["SMIGRATED", SeqID, [[host:port, slots], [host:port, slots], ...]] notification := []interface{}{ "SMIGRATED", int64(12346), - int64(2), // count of endpoints []interface{}{ - "127.0.0.1:6379 123,456,789-1000", - "127.0.0.1:6380 124,457,300-500", + []interface{}{"127.0.0.1:6379", "123,456,789-1000"}, + []interface{}{"127.0.0.1:6380", "124,457,300-500"}, }, } // Verify notification format - if len(notification) != 4 { - t.Fatalf("SMIGRATED notification should have exactly 4 elements, got %d", len(notification)) + if len(notification) != 3 { + t.Fatalf("SMIGRATED notification should have exactly 3 elements, got %d", len(notification)) } notifType, ok := notification[0].(string) @@ -131,40 +129,47 @@ func TestClusterMaintNotifications_SMigratedHandler(t *testing.T) { t.Errorf("Expected SeqID 12346, got %d", seqID) } - // Verify count - count, ok := notification[2].(int64) - if !ok { - t.Fatalf("Expected count to be int64, got %T", notification[2]) - } - if count != 2 { - t.Errorf("Expected count 2, got %d", count) - } - // Verify endpoints array - endpoints, ok := notification[3].([]interface{}) + endpoints, ok := notification[2].([]interface{}) if !ok { - t.Fatalf("Expected endpoints to be array, got %T", notification[3]) + t.Fatalf("Expected endpoints to be array, got %T", notification[2]) } - if len(endpoints) != int(count) { - t.Errorf("Expected %d endpoints, got %d", count, len(endpoints)) + if len(endpoints) != 2 { + t.Errorf("Expected 2 endpoints, got %d", len(endpoints)) } // Verify first endpoint - endpoint1, ok := endpoints[0].(string) + endpoint1, ok := endpoints[0].([]interface{}) if !ok { - t.Fatalf("Expected endpoint to be string, got %T", endpoints[0]) + t.Fatalf("Expected endpoint to be array, got %T", endpoints[0]) } - if endpoint1 != "127.0.0.1:6379 123,456,789-1000" { - t.Errorf("Expected first endpoint '127.0.0.1:6379 123,456,789-1000', got %v", endpoint1) + if len(endpoint1) != 2 { + t.Fatalf("Expected endpoint to have 2 elements, got %d", len(endpoint1)) + } + hostPort1, ok := endpoint1[0].(string) + if !ok || hostPort1 != "127.0.0.1:6379" { + t.Errorf("Expected first endpoint host:port '127.0.0.1:6379', got %v", endpoint1[0]) + } + slots1, ok := endpoint1[1].(string) + if !ok || slots1 != "123,456,789-1000" { + t.Errorf("Expected first endpoint slots '123,456,789-1000', got %v", endpoint1[1]) } // Verify second endpoint - endpoint2, ok := endpoints[1].(string) + endpoint2, ok := endpoints[1].([]interface{}) if !ok { - t.Fatalf("Expected endpoint to be string, got %T", endpoints[1]) + t.Fatalf("Expected endpoint to be array, got %T", endpoints[1]) + } + if len(endpoint2) != 2 { + t.Fatalf("Expected endpoint to have 2 elements, got %d", len(endpoint2)) + } + hostPort2, ok := endpoint2[0].(string) + if !ok || hostPort2 != "127.0.0.1:6380" { + t.Errorf("Expected second endpoint host:port '127.0.0.1:6380', got %v", endpoint2[0]) } - if endpoint2 != "127.0.0.1:6380 124,457,300-500" { - t.Errorf("Expected second endpoint '127.0.0.1:6380 124,457,300-500', got %v", endpoint2) + slots2, ok := endpoint2[1].(string) + if !ok || slots2 != "124,457,300-500" { + t.Errorf("Expected second endpoint slots '124,457,300-500', got %v", endpoint2[1]) } t.Log("SMIGRATED notification format validation passed")