diff --git a/hosting/Windows/Garnet.worker/Program.cs b/hosting/Windows/Garnet.worker/Program.cs
index 8418da86716..bb2e46173ea 100644
--- a/hosting/Windows/Garnet.worker/Program.cs
+++ b/hosting/Windows/Garnet.worker/Program.cs
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation.
 // Licensed under the MIT license.
 
+using System;
 using Garnet;
 using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.Hosting;
@@ -12,6 +13,13 @@ static void Main(string[] args)
         var builder = Host.CreateApplicationBuilder(args);
         builder.Services.AddHostedService(_ => new Worker(args));
 
+        // Configure Host shutdown timeout (the options type must be named explicitly; it cannot be inferred from the lambda)
+        builder.Services.Configure<HostOptions>(options =>
+        {
+            // Set graceful shutdown timeout to 5 seconds
+            options.ShutdownTimeout = TimeSpan.FromSeconds(5);
+        });
+
         builder.Services.AddWindowsService(options =>
         {
             options.ServiceName = "Microsoft Garnet Server";
diff --git a/hosting/Windows/Garnet.worker/Worker.cs b/hosting/Windows/Garnet.worker/Worker.cs
index d69adb7e3c0..133aeea966a 100644
--- a/hosting/Windows/Garnet.worker/Worker.cs
+++ b/hosting/Windows/Garnet.worker/Worker.cs
@@ -43,8 +43,23 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken)
         /// Indicates that the shutdown process should no longer be graceful.
public override async Task StopAsync(CancellationToken cancellationToken)
         {
-            Dispose();
-            await base.StopAsync(cancellationToken).ConfigureAwait(false);
+            try
+            {
+                if (server != null)
+                {
+                    // Drain connections for less than the host's 5s ShutdownTimeout so the AOF commit and checkpoint still run before the host cancels the token
+                    await server.ShutdownAsync(timeout: TimeSpan.FromSeconds(3), token: cancellationToken).ConfigureAwait(false);
+                }
+            }
+            catch (OperationCanceledException)
+            {
+                // Force shutdown requested - proceed to dispose
+            }
+            finally
+            {
+                await base.StopAsync(cancellationToken).ConfigureAwait(false);
+                Dispose();
+            }
         }
 
         public override void Dispose()
@@ -55,6 +70,8 @@ public override void Dispose()
             }
             server?.Dispose();
             _isDisposed = true;
+            base.Dispose();
+            GC.SuppressFinalize(this);
         }
     }
 }
\ No newline at end of file
diff --git a/libs/host/GarnetServer.cs b/libs/host/GarnetServer.cs
index 2d12f43a0d4..9cf79c6d2d3 100644
--- a/libs/host/GarnetServer.cs
+++ b/libs/host/GarnetServer.cs
@@ -10,6 +10,7 @@
 using System.Runtime.InteropServices;
 using System.Text;
 using System.Threading;
+using System.Threading.Tasks;
 using Garnet.cluster;
 using Garnet.common;
 using Garnet.networking;
@@ -422,6 +423,178 @@ public void Start()
             Console.WriteLine("* Ready to accept connections");
         }
 
+        /// <summary>
+        /// Performs graceful shutdown of the server.
+        /// Stops accepting new connections, waits for active connections to complete, commits AOF, and takes checkpoint if needed.
+        /// </summary>
+        /// <param name="timeout">Timeout for waiting on active connections (default: 30 seconds)</param>
+        /// <param name="token">Cancellation token</param>
+        /// <returns>Task representing the async shutdown operation</returns>
+        public async Task ShutdownAsync(TimeSpan? timeout = null, CancellationToken token = default)
+        {
+            var shutdownTimeout = timeout ?? TimeSpan.FromSeconds(30);
+
+            try
+            {
+                // Stop accepting new connections first
+                StopListening();
+
+                // Wait for existing connections to complete
+                await WaitForActiveConnectionsAsync(shutdownTimeout, token).ConfigureAwait(false);
+
+                // Commit AOF and take checkpoint if needed
+                await FinalizeDataAsync(token).ConfigureAwait(false);
+            }
+            catch (OperationCanceledException)
+            {
+                // Force shutdown requested: remaining steps are skipped, nothing is rethrown
+                logger?.LogWarning("Graceful shutdown was cancelled before it completed");
+            }
+            catch (Exception ex)
+            {
+                logger?.LogError(ex, "Error during graceful shutdown");
+            }
+        }
+
+        /// <summary>
+        /// Stop all servers from accepting new connections.
+        /// </summary>
+        private void StopListening()
+        {
+            if (servers == null) return;
+
+            logger?.LogInformation("Stopping listeners to prevent new connections...");
+            foreach (var server in servers)
+            {
+                try
+                {
+                    server?.StopListening();
+                }
+                catch (Exception ex)
+                {
+                    logger?.LogWarning(ex, "Error stopping listener");
+                }
+            }
+        }
+
+        /// <summary>
+        /// Waits for active connections to complete within the specified timeout.
+        /// </summary>
+        private async Task WaitForActiveConnectionsAsync(TimeSpan timeout, CancellationToken token)
+        {
+            if (servers == null) return;
+
+            var stopwatch = Stopwatch.StartNew();
+            var delays = new[] { 50, 300, 1000 };
+            var delayIndex = 0;
+
+            while (stopwatch.Elapsed < timeout)
+            {
+                // Surface cancellation as an exception so the caller skips the finalize step
+                token.ThrowIfCancellationRequested();
+
+                int currentDelay;
+                try
+                {
+                    var activeConnections = GetActiveConnectionCount();
+
+                    if (activeConnections == 0)
+                    {
+                        logger?.LogInformation("All connections have been closed gracefully.");
+                        return;
+                    }
+
+                    logger?.LogInformation("Waiting for {ActiveConnections} active connections to complete...", activeConnections);
+
+                    currentDelay = delays[delayIndex];
+                    if (delayIndex < delays.Length - 1) delayIndex++;
+                }
+                catch (Exception ex)
+                {
+                    logger?.LogWarning(ex, "Error checking active connections");
+                    delayIndex = 0;
+                    currentDelay = 500;
+                }
+
+                // Never sleep past the deadline, otherwise the wait can overshoot the timeout by a full back-off step
+                var remaining = timeout - stopwatch.Elapsed;
+                if (remaining <= TimeSpan.Zero) break;
+                var delay = TimeSpan.FromMilliseconds(Math.Min(currentDelay, remaining.TotalMilliseconds));
+
+                await Task.Delay(delay, token).ConfigureAwait(false);
+            }
+
+            logger?.LogWarning("Timeout reached after {TimeoutSeconds} seconds. Some connections may still be active.",
+                timeout.TotalSeconds);
+        }
+
+        /// <summary>
+        /// Gets the current number of active connections directly from server instances.
+        /// </summary>
+        private int GetActiveConnectionCount()
+        {
+            int count = 0;
+            if (servers != null)
+            {
+                foreach (var garnetServerBase in servers.OfType<GarnetServerBase>())
+                {
+                    count += (int)garnetServerBase.get_conn_active();
+                }
+            }
+            return count;
+        }
+
+        /// <summary>
+        /// Commits AOF and takes checkpoint for data durability during shutdown.
+        /// </summary>
+        private async Task FinalizeDataAsync(CancellationToken token)
+        {
+            // Commit AOF before checkpoint/shutdown
+            if (opts.EnableAOF)
+            {
+                logger?.LogInformation("Committing AOF before shutdown...");
+                try
+                {
+                    var commitSuccess = await Store.CommitAOFAsync(token).ConfigureAwait(false);
+                    if (commitSuccess)
+                    {
+                        logger?.LogInformation("AOF committed successfully.");
+                    }
+                    else
+                    {
+                        logger?.LogInformation("AOF commit skipped (another commit in progress or replica mode).");
+                    }
+                }
+                catch (Exception ex) when (!(ex is OperationCanceledException))
+                {
+                    // Cancellation is not a failure: it propagates so the caller treats it as a forced shutdown
+                    logger?.LogError(ex, "Error committing AOF during shutdown");
+                }
+            }
+
+            // Take checkpoint for tiered storage
+            if (opts.EnableStorageTier)
+            {
+                logger?.LogInformation("Taking checkpoint for tiered storage...");
+                try
+                {
+                    var checkpointSuccess = Store.TakeCheckpoint(background: false, token);
+                    if (checkpointSuccess)
+                    {
+                        logger?.LogInformation("Checkpoint completed successfully.");
+                    }
+                    else
+                    {
+                        logger?.LogInformation("Checkpoint skipped (another checkpoint in progress or replica mode).");
+                    }
+                }
+                catch (Exception ex) when (!(ex is OperationCanceledException))
+                {
+                    logger?.LogError(ex, "Error taking checkpoint during shutdown");
+                }
+            }
+        }
+
         ///
         /// Dispose store (including log and checkpoint directory)
         ///
diff --git a/libs/server/Servers/GarnetServerBase.cs b/libs/server/Servers/GarnetServerBase.cs
index 5bfd1ff62ff..7f6386ac523
100644
--- a/libs/server/Servers/GarnetServerBase.cs
+++ b/libs/server/Servers/GarnetServerBase.cs
@@ -154,6 +154,12 @@ public bool AddSession(WireFormat protocol, ref ISessionProvider provider, INetw
         ///
         public abstract void Start();
 
+        /// <inheritdoc />
+        public virtual void StopListening()
+        {
+            // Base implementation does nothing; derived classes should override
+        }
+
         ///
         public virtual void Dispose()
         {
diff --git a/libs/server/Servers/GarnetServerTcp.cs b/libs/server/Servers/GarnetServerTcp.cs
index c681e09befa..be0ea105a4b 100644
--- a/libs/server/Servers/GarnetServerTcp.cs
+++ b/libs/server/Servers/GarnetServerTcp.cs
@@ -28,6 +28,7 @@ public class GarnetServerTcp : GarnetServerBase, IServerHook
         readonly int networkConnectionLimit;
         readonly string unixSocketPath;
         readonly UnixFileMode unixSocketPermission;
+        volatile bool isListening;
 
         ///
         public override IEnumerable ActiveConsumers()
@@ -117,19 +118,49 @@ public override void Start()
             }
 
             listenSocket.Listen(512);
+            isListening = true;
             if (!listenSocket.AcceptAsync(acceptEventArg))
                 AcceptEventArg_Completed(null, acceptEventArg);
         }
 
+        /// <inheritdoc />
+        public override void StopListening()
+        {
+            if (!isListening)
+                return;
+
+            isListening = false;
+            try
+            {
+                // Close the listen socket to stop accepting new connections
+                // This will cause any pending AcceptAsync to complete with an error
+                listenSocket.Close();
+                logger?.LogInformation("Stopped accepting new connections on {endpoint}", EndPoint);
+            }
+            catch (Exception ex)
+            {
+                logger?.LogWarning(ex, "Error closing listen socket on {endpoint}", EndPoint);
+            }
+        }
+
         private void AcceptEventArg_Completed(object sender, SocketAsyncEventArgs e)
         {
             try
             {
                 do
                 {
+                    // Once StopListening has run, refuse the connection: dispose a socket that was accepted
+                    // concurrently with the shutdown so the client is not left connected but never served
+                    if (!isListening)
+                    {
+                        e.AcceptSocket?.Dispose();
+                        e.AcceptSocket = null;
+                        break;
+                    }
+
                     if (!HandleNewConnection(e)) break;
                     e.AcceptSocket = null;
-                } while (!listenSocket.AcceptAsync(e));
+                } while (isListening && !listenSocket.AcceptAsync(e));
             }
             // socket disposed
             catch (ObjectDisposedException) { }
diff --git
a/libs/server/Servers/IGarnetServer.cs b/libs/server/Servers/IGarnetServer.cs
index 9e197451627..639e4a2a259 100644
--- a/libs/server/Servers/IGarnetServer.cs
+++ b/libs/server/Servers/IGarnetServer.cs
@@ -46,5 +46,11 @@ public interface IGarnetServer : IDisposable
         /// Start server
         ///
         public void Start();
+
+        /// <summary>
+        /// Stop accepting new connections (for graceful shutdown).
+        /// Existing connections remain active until they complete or are disposed.
+        /// </summary>
+        public void StopListening();
     }
 }
\ No newline at end of file
diff --git a/libs/server/Servers/StoreApi.cs b/libs/server/Servers/StoreApi.cs
index 5ff169c9fd5..2ed70ef7a3e 100644
--- a/libs/server/Servers/StoreApi.cs
+++ b/libs/server/Servers/StoreApi.cs
@@ -130,6 +130,32 @@ public bool FlushDB(int dbId = 0, bool unsafeTruncateLog = false)
             }
         }
 
+        /// <summary>
+        /// Take checkpoint for all active databases
+        /// </summary>
+        /// <param name="background">True if method can return before checkpoint is taken</param>
+        /// <param name="token">Cancellation token</param>
+        /// <returns>false if checkpoint was skipped due to node state or another checkpoint in progress</returns>
+        public bool TakeCheckpoint(bool background = false, CancellationToken token = default)
+        {
+            using (PreventRoleChange(out var acquired))
+            {
+                // Skip when the role-change guard was not acquired or this node is a replica
+                var canCheckpoint = acquired && !IsReplica;
+                return canCheckpoint && storeWrapper.TakeCheckpoint(background, logger: null, token: token);
+            }
+        }
+
+        /// <summary>
+        /// Check if storage tier is enabled
+        /// </summary>
+        public bool IsStorageTierEnabled => storeWrapper.serverOptions.EnableStorageTier;
+
+        /// <summary>
+        /// Check if AOF is enabled
+        /// </summary>
+        public bool IsAOFEnabled => storeWrapper.serverOptions.EnableAOF;
+
         ///
         /// Helper to disable role changes during a using block.
///
diff --git a/main/GarnetServer/Program.cs b/main/GarnetServer/Program.cs
index 7b2673ebc41..f7463e4db8f 100644
--- a/main/GarnetServer/Program.cs
+++ b/main/GarnetServer/Program.cs
@@ -10,23 +10,74 @@ namespace Garnet
     ///
     public class Program
     {
-        static void Main(string[] args)
+        static async Task Main(string[] args)
         {
+            GarnetServer server = null;
+            using var shutdownCts = new CancellationTokenSource();
+            int shutdownInitiated = 0; // Guard to ensure single shutdown/dispose
+            int serverStarted = 0; // Guard to track if server started successfully
+
             try
             {
-                using var server = new GarnetServer(args);
+                server = new GarnetServer(args);
 
                 // Optional: register custom extensions
                 RegisterExtensions(server);
 
+                // Set up graceful shutdown handlers for Ctrl+C and SIGTERM
+                Console.CancelKeyPress += (sender, e) =>
+                {
+                    e.Cancel = true; // Prevent immediate termination
+                    Console.WriteLine("Shutdown signal received. Starting graceful shutdown...");
+                    shutdownCts.Cancel();
+                };
+
+                AppDomain.CurrentDomain.ProcessExit += (sender, e) =>
+                {
+                    // Check that the server started before taking the single-shutdown guard,
+                    // so an exit during startup does not consume the guard without doing any cleanup
+                    if (Interlocked.CompareExchange(ref serverStarted, 0, 0) == 1 &&
+                        Interlocked.Exchange(ref shutdownInitiated, 1) == 0)
+                    {
+                        Console.WriteLine("Process exit signal received. Starting graceful shutdown...");
+                        shutdownCts.Cancel();
+                        // This handler is synchronous, so blocking is the only way to wait for the shutdown here
+                        server?.ShutdownAsync(TimeSpan.FromSeconds(3), CancellationToken.None)
+                            .GetAwaiter().GetResult();
+                        server?.Dispose();
+                    }
+                };
+
                 // Start the server
                 server.Start();
+                Interlocked.Exchange(ref serverStarted, 1); // Mark server as started
+
+                // Wait for shutdown signal
+                try
+                {
+                    await Task.Delay(Timeout.Infinite, shutdownCts.Token).ConfigureAwait(false);
+                }
+                catch (OperationCanceledException)
+                {
+                    // Normal shutdown path
+                }
 
-                Thread.Sleep(Timeout.Infinite);
+                // Skip if the ProcessExit handler already ran the shutdown (the server is known to be started here)
+                if (Interlocked.Exchange(ref shutdownInitiated, 1) == 0)
+                {
+                    // Main is async, so await the shutdown instead of blocking a thread on it
+                    await server.ShutdownAsync(TimeSpan.FromSeconds(5), CancellationToken.None).ConfigureAwait(false);
+                    server.Dispose();
+                }
             }
             catch (Exception ex)
             {
                 Console.WriteLine($"Unable to initialize server due to exception: {ex.Message}");
+                // Ensure cleanup on exception if shutdown wasn't initiated
+                if (Interlocked.Exchange(ref shutdownInitiated, 1) == 0)
+                {
+                    server?.Dispose();
+                }
             }
         }
 
diff --git a/test/Garnet.test/GarnetServerTcpTests.cs b/test/Garnet.test/GarnetServerTcpTests.cs
new file mode 100644
index 00000000000..852f0332a6c
--- /dev/null
+++ b/test/Garnet.test/GarnetServerTcpTests.cs
@@ -0,0 +1,141 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Allure.NUnit;
+using Garnet.server;
+using NUnit.Framework;
+using NUnit.Framework.Legacy;
+using StackExchange.Redis;
+
+namespace Garnet.test
+{
+    [AllureNUnit]
+    [TestFixture, NonParallelizable]
+    public class GarnetServerTcpTests : AllureTestBase
+    {
+        private GarnetServer server;
+
+        [SetUp]
+        public void Setup()
+        {
+            TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true);
+            server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir);
+            server.Start();
+        }
+
+        [TearDown]
+        public void TearDown()
+        {
+            server?.Dispose();
+            TestUtils.DeleteDirectory(TestUtils.MethodTestDir);
+        }
+
+        [Test]
+        public void StopListeningPreventsNewConnections()
+        {
+            // Arrange - Establish a working connection first
+            using var redis1 = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+            var db1 = redis1.GetDatabase(0);
+            db1.StringSet("test", "value");
+            ClassicAssert.AreEqual("value", (string)db1.StringGet("test"));
+
+            // Act - Stop listening on all servers
+            foreach (var tcpServer in server.Provider.StoreWrapper.Servers.OfType<GarnetServerTcp>())
+            {
+                tcpServer.StopListening();
+            }
+
+            Thread.Sleep(100); // Brief delay to ensure socket is closed
+
+            // Assert - New connections should fail (NOTE(review): expected exception type assumed to be RedisConnectionException)
+            Assert.Throws<RedisConnectionException>(() =>
+            {
+                using var redis2 = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+                redis2.GetDatabase(0).Ping();
+            });
+
+            // Existing connection should still work
+            ClassicAssert.AreEqual("value", (string)db1.StringGet("test"));
+        }
+
+        [Test]
+        public void StopListeningIdempotent()
+        {
+            // Arrange
+            foreach (var tcpServer in server.Provider.StoreWrapper.Servers.OfType<GarnetServerTcp>())
+            {
+                tcpServer.StopListening();
+            }
+
+            // Act & Assert - Calling StopListening again should not throw
+            Assert.DoesNotThrow(() =>
+            {
+                foreach (var tcpServer in server.Provider.StoreWrapper.Servers.OfType<GarnetServerTcp>())
+                {
+                    tcpServer.StopListening();
+                }
+            });
+        }
+
+        [Test]
+        public void StopListeningLogsInformation()
+        {
+            // This test is meant to verify that StopListening logs appropriate information
+            // TODO: set up a logger and verify the log output
+            // For now, we just verify no exceptions are thrown
+
+            Assert.DoesNotThrow(() =>
+            {
+                foreach (var tcpServer in server.Provider.StoreWrapper.Servers.OfType<GarnetServerTcp>())
+                {
+                    tcpServer.StopListening();
+                }
+            });
+        }
+
+        [Test]
+        public async Task StopListeningDuringActiveConnectionAttempts()
+        {
+            // Arrange - Start multiple connection attempts
+            var connectionTasks = new System.Collections.Generic.List<Task>();
+            using var cts = new CancellationTokenSource();
+
+            for (int i = 0; i < 10; i++)
+            {
+                connectionTasks.Add(Task.Run(async () =>
+                {
+                    while (!cts.Token.IsCancellationRequested)
+                    {
+                        try
+                        {
+                            using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+                            await redis.GetDatabase(0).PingAsync();
+                            await Task.Delay(10);
+                        }
+                        catch
+                        {
+                            // Connection failures are expected after StopListening
+                        }
+                    }
+                })); // no token passed to Task.Run: a task still queued at cancel time would end Canceled and fail WhenAll
+            }
+
+            await Task.Delay(50); // Let some connections establish
+
+            // Act
+            foreach (var tcpServer in server.Provider.StoreWrapper.Servers.OfType<GarnetServerTcp>())
+            {
+                tcpServer.StopListening();
+            }
+
+            await Task.Delay(100);
+            cts.Cancel();
+
+            // Assert - All tasks should complete without unhandled exceptions
+            Assert.DoesNotThrowAsync(async () => await Task.WhenAll(connectionTasks));
+        }
+    }
+}
\ No newline at end of file
diff --git a/test/Garnet.test/RespAdminCommandsTests.cs b/test/Garnet.test/RespAdminCommandsTests.cs
index 6fa5b9dd509..dae1f31577a 100644
--- a/test/Garnet.test/RespAdminCommandsTests.cs
+++ b/test/Garnet.test/RespAdminCommandsTests.cs
@@ -662,5 +662,181 @@ public void ConfigGetWrongNumberOfArguments()
             ClassicAssert.AreEqual(expectedMessage, ex.Message);
         }
         #endregion
+
+        #region GracefulShutdownTests
+        [Test]
+        public async Task ShutdownAsyncStopsAcceptingNewConnections()
+        {
+            // Arrange
+            server.Dispose();
+            var testServer = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir + "_shutdown");
+            try
+            {
+                testServer.Start();
+
+                using var redis1 = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+                var db1 = redis1.GetDatabase(0);
+                db1.StringSet("test", "value");
+
+                // Act - Initiate shutdown (no need for Task.Run, ShutdownAsync is already async)
+                var shutdownTask = testServer.ShutdownAsync(TimeSpan.FromSeconds(5));
+
+                // Give shutdown a moment to stop listening
+                await Task.Delay(200);
+
+                // Assert - New connections should fail
+                // NOTE(review): expected exception type assumed to be RedisConnectionException
+                var ex = Assert.ThrowsAsync<RedisConnectionException>(async () =>
+                {
+                    using var redis2 = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+                    await redis2.GetDatabase(0).PingAsync();
+                });
+                ClassicAssert.IsNotNull(ex, "Expected connection to fail after shutdown initiated");
+
+                // Close the client so the shutdown does not sit out its full connection-drain timeout
+                await redis1.CloseAsync();
+                await shutdownTask;
+            }
+            finally
+            {
+                // Release the listening port even when an assertion fails
+                testServer.Dispose();
+            }
+        }
+
+        [Test]
+        public async Task ShutdownAsyncWaitsForActiveConnections()
+        {
+            // Arrange
+            server.Dispose();
+            var testServer = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir + "_shutdown2");
+            try
+            {
+                testServer.Start();
+
+                using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+                var db = redis.GetDatabase(0);
+
+                // Set initial value
+                db.StringSet("key1", "value1");
+
+                // Act - Start shutdown while connection is active
+                var shutdownTask = testServer.ShutdownAsync(TimeSpan.FromSeconds(10));
+
+                // Connection should still work during grace period
+                // Perform multiple operations to ensure connection remains active
+                var result = db.StringGet("key1");
+                ClassicAssert.AreEqual("value1", (string)result);
+
+                // Verify we can still perform operations during grace period
+                db.StringSet("key2", "value2");
+                var result2 = db.StringGet("key2");
+                ClassicAssert.AreEqual("value2", (string)result2);
+
+                // Close the client so the shutdown completes as soon as the connection drains
+                await redis.CloseAsync();
+                await shutdownTask;
+            }
+            finally
+            {
+                testServer.Dispose();
+            }
+        }
+
+        [Test]
+        public async Task ShutdownAsyncCommitsAOF()
+        {
+            // Arrange
+            server.Dispose();
+            server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, enableAOF: true);
+            server.Start();
+
+            using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+            {
+                var db = redis.GetDatabase(0);
+                db.StringSet("aofKey", "aofValue");
+            }
+
+            // Act - Shutdown which should commit AOF
+            await server.ShutdownAsync(TimeSpan.FromSeconds(5));
+            server.Dispose(false);
+
+            // Assert - Recover and verify data persisted
+            server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, enableAOF: true, tryRecover: true);
+            server.Start();
+
+            using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()))
+            {
+                var db = redis.GetDatabase(0);
+                var recoveredValue = db.StringGet("aofKey");
+                ClassicAssert.AreEqual("aofValue", recoveredValue.ToString());
+            }
+        }
+
+        [Test]
+        public async Task ShutdownAsyncTakesCheckpointWhenStorageTierEnabled()
+        {
+            // Arrange
+            server.Dispose();
+            // Storage tier is enabled by default when logCheckpointDir is provided
+            server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir);
+            server.Start();
+
+            using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true)))
+            {
+                var db = redis.GetDatabase(0);
+                db.StringSet("checkpointKey", "checkpointValue");
+            }
+
+            // Act - Shutdown which should take checkpoint
+            await server.ShutdownAsync(TimeSpan.FromSeconds(5));
+            server.Dispose(false);
+
+            // Assert - Recover from checkpoint
+            server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true);
+            server.Start();
+
+            using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()))
+            {
+                var db = redis.GetDatabase(0);
+                var recoveredValue = db.StringGet("checkpointKey");
+                ClassicAssert.AreEqual("checkpointValue", recoveredValue.ToString());
+            }
+        }
+
+        [Test]
+        public async Task ShutdownAsyncRespectsTimeout()
+        {
+            // This test verifies that shutdown respects the timeout parameter
+            // Arrange
+            server.Dispose();
+            var testServer = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir + "_timeout");
+            try
+            {
+                testServer.Start();
+
+                // Create a connection that will remain active
+                using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig());
+                var db = redis.GetDatabase(0);
+                db.StringSet("key", "value");
+
+                // Act - Shutdown with very short timeout (100ms)
+                // With an active connection, shutdown should timeout quickly rather than waiting indefinitely
+                var stopwatch = System.Diagnostics.Stopwatch.StartNew();
+                await testServer.ShutdownAsync(TimeSpan.FromMilliseconds(100));
+                stopwatch.Stop();
+
+                // Assert - Should complete within reasonable time (timeout + some overhead for AOF/checkpoint)
+                // The timeout is for waiting on connections, but shutdown also does AOF commit and checkpoint
+                // So we allow more time than the timeout itself
+                ClassicAssert.Less(stopwatch.ElapsedMilliseconds, 2000,
+                    $"Shutdown should complete within reasonable time. Actual: {stopwatch.ElapsedMilliseconds}ms");
+            }
+            finally
+            {
+                testServer.Dispose();
+            }
+        }
+        #endregion
     }
 }
\ No newline at end of file