diff --git a/config/src/main/java/com/yahoo/vespa/config/Connection.java b/config/src/main/java/com/yahoo/vespa/config/Connection.java index ae850e68255f..b1f09be6623f 100644 --- a/config/src/main/java/com/yahoo/vespa/config/Connection.java +++ b/config/src/main/java/com/yahoo/vespa/config/Connection.java @@ -17,4 +17,6 @@ public interface Connection { String getAddress(); + default void closeConnection() {} + } diff --git a/config/src/main/java/com/yahoo/vespa/config/JRTConnection.java b/config/src/main/java/com/yahoo/vespa/config/JRTConnection.java index e88c8293eabf..eb076559fc79 100644 --- a/config/src/main/java/com/yahoo/vespa/config/JRTConnection.java +++ b/config/src/main/java/com/yahoo/vespa/config/JRTConnection.java @@ -43,6 +43,21 @@ public String getAddress() { return address; } + public synchronized void closeTarget() { + if (target != null) { + logger.log(Level.INFO, "Force-closing connection to " + address + + " (target valid=" + target.isValid() + ")"); + target.closeSocket(); // Synchronously close TCP socket before async cleanup + target.close(); + target = null; + } + } + + @Override + public void closeConnection() { + closeTarget(); + } + /** * This is synchronized to avoid multiple ConfigInstances creating new targets simultaneously, if * the existing target is null, invalid or has not yet been initialized. diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java index c7a4d583ecf2..7a5959c086de 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java @@ -62,11 +62,27 @@ public FileDownloader(ConnectionPool connectionPool, this.timeout = timeout; // Needed to receive RPC receiveFile* calls from server after starting download of file reference new FileReceiver(supervisor, downloads, downloadDirectory); - this.fileReferenceDownloader = new FileReferenceDownloader(connectionPool, - downloads, - timeout, - backoffInitialTime, - downloadDirectory); + this.fileReferenceDownloader = new FileReferenceDownloader(connectionPool, downloads, timeout, + backoffInitialTime, downloadDirectory); + if (forceDownload) + log.log(Level.INFO, "Force download of file references (download even if file reference exists on disk)"); + } + + public FileDownloader(ConnectionPool connectionPool, + Supervisor supervisor, + File downloadDirectory, + Duration timeout, + Duration backoffInitialTime, + int maxTimeoutsBeforeClose) { + this.connectionPool = connectionPool; + this.supervisor = supervisor; + this.downloadDirectory = downloadDirectory; + this.timeout = timeout; + // Needed to receive RPC receiveFile* calls from server after starting download of file reference + new FileReceiver(supervisor, downloads, downloadDirectory); + this.fileReferenceDownloader = new FileReferenceDownloader(connectionPool, downloads, timeout, + backoffInitialTime, downloadDirectory, + maxTimeoutsBeforeClose); if (forceDownload) log.log(Level.INFO, "Force download of file references (download even if file reference exists on disk)"); } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java index 843502da4124..073f27ed8ef9 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java @@ -3,6 +3,7 @@ import com.yahoo.concurrent.DaemonThreadFactory; import com.yahoo.config.FileReference; +import com.yahoo.jrt.ErrorCode; import com.yahoo.jrt.Int32Value; import com.yahoo.jrt.Request; import com.yahoo.jrt.Spec; @@ -39,6 +40,8 @@ public class FileReferenceDownloader { private static final Logger log = Logger.getLogger(FileReferenceDownloader.class.getName()); private static final Set defaultAcceptedCompressionTypes = Set.of(lz4, none, zstd); + private enum DownloadResult { SUCCESS, TIMEOUT, FAILURE } + private final ExecutorService downloadExecutor = Executors.newFixedThreadPool(Math.max(8, Runtime.getRuntime().availableProcessors()), new DaemonThreadFactory("filereference downloader")); @@ -49,17 +52,31 @@ public class FileReferenceDownloader { private final Optional rpcTimeout; // Only used when overridden with env variable private final File downloadDirectory; private final AtomicBoolean shutDown = new AtomicBoolean(false); + private final int maxTimeoutsBeforeClose; FileReferenceDownloader(ConnectionPool connectionPool, Downloads downloads, Duration timeout, Duration backoffInitialTime, File downloadDirectory) { + this(connectionPool, downloads, timeout, backoffInitialTime, downloadDirectory, + Optional.ofNullable(System.getenv("VESPA_FILE_DOWNLOAD_MAX_TIMEOUTS_BEFORE_CLOSE")) + .map(Integer::parseInt) + .orElse(0)); + } + + FileReferenceDownloader(ConnectionPool connectionPool, + Downloads downloads, + Duration timeout, + Duration backoffInitialTime, + File downloadDirectory, + int maxTimeoutsBeforeClose) { this.connectionPool = connectionPool; this.downloads = downloads; this.downloadTimeout = timeout; this.backoffInitialTime = backoffInitialTime; this.downloadDirectory = downloadDirectory; + this.maxTimeoutsBeforeClose = maxTimeoutsBeforeClose; // Undocumented on purpose, might change or be removed at any time var timeoutString = Optional.ofNullable(System.getenv("VESPA_FILE_DOWNLOAD_RPC_TIMEOUT")); this.rpcTimeout = timeoutString.map(t -> Duration.ofSeconds(Integer.parseInt(t))); @@ -69,6 +86,7 @@ private void waitUntilDownloadStarted(FileReferenceDownload fileReferenceDownloa Instant end = Instant.now().plus(downloadTimeout); FileReference fileReference = fileReferenceDownload.fileReference(); int retryCount = 0; + int timeoutCount = 0; Connection connection = connectionPool.getCurrent(); do { if (retryCount > 0) @@ -81,8 +99,19 @@ private void waitUntilDownloadStarted(FileReferenceDownload fileReferenceDownloa var timeout = rpcTimeout.orElse(Duration.between(Instant.now(), end)); log.log(Level.FINE, "Wait until download of " + fileReference + " has started, retryCount " + retryCount + ", timeout " + timeout + " (request from " + fileReferenceDownload.client() + ")"); - if ( ! timeout.isNegative() && startDownloadRpc(fileReferenceDownload, retryCount, connection, timeout)) - return; + if ( ! timeout.isNegative()) { + var result = startDownloadRpc(fileReferenceDownload, retryCount, connection, timeout); + if (result == DownloadResult.SUCCESS) return; + if (result == DownloadResult.TIMEOUT && maxTimeoutsBeforeClose > 0) { + timeoutCount++; + if (timeoutCount >= maxTimeoutsBeforeClose) { + log.log(Level.INFO, "RPC request for " + fileReference + " timed out " + timeoutCount + + " times, force-closing connection to " + connection.getAddress()); + connection.closeConnection(); + timeoutCount = 0; + } + } + } retryCount++; // There might not be one connection that works for all file references (each file reference might @@ -131,10 +160,13 @@ void startDownloadFromSource(FileReferenceDownload fileReferenceDownload, Spec s log.log(Level.FINE, () -> "Will download " + fileReference + " with timeout " + downloadTimeout + " from " + spec.host()); downloads.add(fileReferenceDownload); - var downloading = startDownloadRpc(fileReferenceDownload, 1, connection, downloadTimeout); + var result = startDownloadRpc(fileReferenceDownload, 1, connection, downloadTimeout); + if (result == DownloadResult.TIMEOUT && maxTimeoutsBeforeClose > 0) { + connection.closeConnection(); + } // Need to explicitly remove from downloads if downloading has not started. // If downloading *has* started FileReceiver will take care of that when download has completed or failed - if ( ! downloading) + if (result != DownloadResult.SUCCESS) downloads.remove(fileReference); }); } @@ -144,7 +176,7 @@ void failedDownloading(FileReference fileReference) { downloads.remove(fileReference); } - private boolean startDownloadRpc(FileReferenceDownload fileReferenceDownload, int retryCount, Connection connection, Duration timeout) { + private DownloadResult startDownloadRpc(FileReferenceDownload fileReferenceDownload, int retryCount, Connection connection, Duration timeout) { Request request = createRequest(fileReferenceDownload); connection.invokeSync(request, timeout); @@ -157,18 +189,18 @@ private boolean startDownloadRpc(FileReferenceDownload fileReferenceDownload, in if (errorCode == 0) { log.log(Level.FINE, () -> "Found " + fileReference + " available at " + address); - return true; + return DownloadResult.SUCCESS; } else { var error = FileApiErrorCodes.get(errorCode); log.log(logLevel, "Downloading " + fileReference + " from " + address + " failed (" + error + ")"); - return false; + return DownloadResult.FAILURE; } } else { log.log(logLevel, "Downloading " + fileReference + " from " + address + " (client " + fileReferenceDownload.client() + ") failed:" + " error code " + request.errorCode() + " (" + request.errorMessage() + ")." + " (retry " + retryCount + ", rpc timeout " + timeout + ")"); - return false; + return request.errorCode() == ErrorCode.TIMEOUT ? DownloadResult.TIMEOUT : DownloadResult.FAILURE; } } diff --git a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java index 40dacd175e97..619c9ed162c3 100644 --- a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java +++ b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java @@ -32,6 +32,7 @@ import java.util.concurrent.Future; import static com.yahoo.jrt.ErrorCode.CONNECTION; +import static com.yahoo.jrt.ErrorCode.TIMEOUT; import static com.yahoo.vespa.filedistribution.FileReferenceData.CompressionType.zstd; import static com.yahoo.vespa.filedistribution.FileReferenceData.Type; import static com.yahoo.vespa.filedistribution.FileReferenceData.Type.compressed; @@ -264,6 +265,66 @@ public void receiveFile() throws IOException { assertEquals("content", IOUtils.readFile(downloadedFile)); } + @Test + public void testConnectionCloseOnTimeout() { + int timesToTimeout = 2; + MockConnection mockConnection = new MockConnection(); + MockConnection.TimeoutResponseHandler responseHandler = + new MockConnection.TimeoutResponseHandler(timesToTimeout); + mockConnection.setResponseHandler(responseHandler); + + FileDownloader downloader = new FileDownloader(mockConnection, supervisor, downloadDir, + Duration.ofSeconds(4), sleepBetweenRetries, + 1); + FileReference fileReference = new FileReference("timeoutTest"); + // File won't be found, download will fail after retries and timeout + assertFalse(downloader.getFile(new FileReferenceDownload(fileReference, "test")).isPresent()); + assertEquals("Expected closeConnection called for each timeout, got " + mockConnection.getCloseConnectionCount(), + timesToTimeout, mockConnection.getCloseConnectionCount()); + downloader.close(); + } + + @Test + public void testConnectionCloseAfterNTimeouts() { + int timesToTimeout = 6; + int retriesOnTimeoutBeforeClose = 2; + MockConnection mockConnection = new MockConnection(); + MockConnection.TimeoutResponseHandler responseHandler = + new MockConnection.TimeoutResponseHandler(timesToTimeout); + mockConnection.setResponseHandler(responseHandler); + + FileDownloader downloader = new FileDownloader(mockConnection, supervisor, downloadDir, + Duration.ofSeconds(4), sleepBetweenRetries, + retriesOnTimeoutBeforeClose); + FileReference fileReference = new FileReference("timeoutNTest"); + // File won't be found, download will fail after retries and timeout + assertFalse(downloader.getFile(new FileReferenceDownload(fileReference, "test")).isPresent()); + // With retriesOnTimeoutBeforeClose=2, close happens after the 2nd timeout (count 1,2 -> close), + // then counter resets, so 6 timeouts / 2 = 3 closes + assertEquals("Expected 3 closeConnection calls for 6 timeouts with threshold 2, got " + mockConnection.getCloseConnectionCount(), + 3, mockConnection.getCloseConnectionCount()); + downloader.close(); + } + + @Test + public void testNoConnectionCloseOnTimeoutByDefault() { + int timesToTimeout = 2; + MockConnection mockConnection = new MockConnection(); + MockConnection.TimeoutResponseHandler responseHandler = + new MockConnection.TimeoutResponseHandler(timesToTimeout); + mockConnection.setResponseHandler(responseHandler); + + // maxTimeoutsBeforeClose=0 means timeout-based close feature is disabled + FileDownloader downloader = new FileDownloader(mockConnection, supervisor, downloadDir, + Duration.ofSeconds(4), sleepBetweenRetries, + 0); + FileReference fileReference = new FileReference("timeoutDefaultTest"); + assertFalse(downloader.getFile(new FileReferenceDownload(fileReference, "test")).isPresent()); + assertEquals("Expected no closeConnection calls when feature is disabled, got " + mockConnection.getCloseConnectionCount(), + 0, mockConnection.getCloseConnectionCount()); + downloader.close(); + } + private void writeFileReference(File dir, String fileReferenceString, String fileName) throws IOException { File fileReferenceDir = new File(dir, fileReferenceString); fileReferenceDir.mkdir(); @@ -307,6 +368,7 @@ private FileDownloader createDownloader(MockConnection connection, Duration time private static class MockConnection implements ConnectionPool, com.yahoo.vespa.config.Connection { private ResponseHandler responseHandler; + private int closeConnectionCount = 0; MockConnection() { this(new FileReferenceFoundResponseHandler()); @@ -328,7 +390,16 @@ public void invokeSync(Request request, Duration jrtTimeout) { @Override public String getAddress() { - return null; + return "localhost"; + } + + @Override + public void closeConnection() { + closeConnectionCount++; + } + + int getCloseConnectionCount() { + return closeConnectionCount; } @Override @@ -403,6 +474,30 @@ public void request(Request request) { } } + static class TimeoutResponseHandler implements MockConnection.ResponseHandler { + + private final int timesToTimeout; + private int timedOutTimes = 0; + + TimeoutResponseHandler(int timesToTimeout) { + super(); + this.timesToTimeout = timesToTimeout; + } + + @Override + public void request(Request request) { + if (request.methodName().equals("filedistribution.serveFile")) { + if (timedOutTimes < timesToTimeout) { + request.setError(TIMEOUT, "Request timed out"); + timedOutTimes++; + } else { + request.returnValues().add(new Int32Value(0)); + request.returnValues().add(new StringValue("OK")); + } + } + } + } + static class ConnectionErrorResponseHandler implements MockConnection.ResponseHandler { private final int timesToFail; diff --git a/jrt/src/com/yahoo/jrt/Connection.java b/jrt/src/com/yahoo/jrt/Connection.java index c757a3240c1e..30ba7976cccc 100644 --- a/jrt/src/com/yahoo/jrt/Connection.java +++ b/jrt/src/com/yahoo/jrt/Connection.java @@ -48,6 +48,7 @@ class Connection extends Target { private final Supervisor owner; private final Spec spec; private CryptoSocket socket; + private volatile SocketChannel channelForClose; private int readSize = READ_SIZE; private final boolean server; private final AtomicLong requestId = new AtomicLong(0); @@ -90,6 +91,7 @@ public Connection(TransportThread parent, Supervisor owner, this.parent = parent; this.owner = owner; this.socket = parent.transport().createServerCryptoSocket(channel); + this.channelForClose = socket.channel(); this.spec = null; this.tcpNoDelay = tcpNoDelay; maxInputSize = owner.getMaxInputBufferSize(); @@ -165,6 +167,7 @@ public Connection connect() { } try { socket = parent.transport().createClientCryptoSocket(SocketChannel.open(spec.resolveAddress()), spec); + channelForClose = socket.channel(); // volatile write for cross-thread visibility } catch (Exception e) { setLostReason(e); } @@ -178,6 +181,8 @@ public boolean init(Selector selector) { try { socket.channel().configureBlocking(false); socket.channel().socket().setTcpNoDelay(tcpNoDelay); + socket.channel().socket().setKeepAlive(true); + socket.channel().socket().setSoLinger(true, 0); selectionKey = socket.channel().register(selector, SelectionKey.OP_READ | SelectionKey.OP_WRITE, this); @@ -395,12 +400,30 @@ public boolean hasSocket() { return ((socket != null) && (socket.channel() != null)); } + @Override public void closeSocket() { - if (hasSocket()) { - try { - socket.channel().socket().close(); - } catch (Exception e) { - log.log(Level.WARNING, "Error closing connection", e); + SocketChannel ch = channelForClose; // volatile read — guaranteed cross-thread visibility + if (ch != null) { + if (ch.isOpen()) { + String socketInfo = ""; + try { + socketInfo = "local=" + ch.socket().getLocalPort() + + " remote=" + ch.socket().getRemoteSocketAddress(); + } catch (Exception ignored) {} + try { + log.log(Level.INFO, "Closing socket channel: " + socketInfo); + ch.close(); + } catch (Exception e) { + log.log(Level.WARNING, "Error closing socket channel: " + socketInfo, e); + } + } + } else { + if (hasSocket()) { + try { + socket.channel().close(); + } catch (Exception e) { + log.log(Level.WARNING, "Error closing connection", e); + } } } } diff --git a/jrt/src/com/yahoo/jrt/Target.java b/jrt/src/com/yahoo/jrt/Target.java index a8253b025fd2..f32b025f58b8 100644 --- a/jrt/src/com/yahoo/jrt/Target.java +++ b/jrt/src/com/yahoo/jrt/Target.java @@ -161,6 +161,12 @@ private static double toSeconds(Duration duration) { */ public abstract boolean removeWatcher(TargetWatcher watcher); + /** + * Synchronously close the underlying TCP socket. This is a no-op by default; + * subclasses backed by a real socket connection should override this. + */ + public void closeSocket() {} + /** * Close this target. Note that the close operation is * asynchronous. If you need to wait for the target to become diff --git a/jrt/tests/com/yahoo/jrt/ConnectionTest.java b/jrt/tests/com/yahoo/jrt/ConnectionTest.java new file mode 100644 index 000000000000..ce481f668963 --- /dev/null +++ b/jrt/tests/com/yahoo/jrt/ConnectionTest.java @@ -0,0 +1,42 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jrt; + +import static org.junit.Assert.assertTrue; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class ConnectionTest { + + @org.junit.Test + public void closeSocketIsVisibleAcrossThreads() throws Exception { + Test.Orb server = new Test.Orb(new Transport()); + Test.Orb client = new Test.Orb(new Transport()); + Acceptor acceptor = server.listen(new Spec(0)); + Target target = client.connect(new Spec("localhost", acceptor.port())); + + // Wait for connection to be established + for (int i = 0; i < 100 && !target.isValid(); i++) { + Thread.sleep(10); + } + assertTrue(target.isValid()); + + // Close the socket from a different thread + CountDownLatch done = new CountDownLatch(1); + boolean[] closed = {false}; + new Thread(() -> { + target.closeSocket(); + closed[0] = true; + done.countDown(); + }).start(); + + assertTrue(done.await(5, TimeUnit.SECONDS)); + assertTrue(closed[0]); + + target.close(); + acceptor.shutdown().join(); + client.transport().shutdown().join(); + server.transport().shutdown().join(); + } + +}