diff --git a/config/src/main/java/com/yahoo/vespa/config/Connection.java b/config/src/main/java/com/yahoo/vespa/config/Connection.java index ae850e68255f..b1f09be6623f 100644 --- a/config/src/main/java/com/yahoo/vespa/config/Connection.java +++ b/config/src/main/java/com/yahoo/vespa/config/Connection.java @@ -17,4 +17,6 @@ public interface Connection { String getAddress(); + default void closeConnection() {} /* Optional hook for closing the underlying connection; the default implementation does nothing, JRTConnection overrides it. */ + } diff --git a/config/src/main/java/com/yahoo/vespa/config/JRTConnection.java b/config/src/main/java/com/yahoo/vespa/config/JRTConnection.java index e88c8293eabf..984f59e7c050 100644 --- a/config/src/main/java/com/yahoo/vespa/config/JRTConnection.java +++ b/config/src/main/java/com/yahoo/vespa/config/JRTConnection.java @@ -43,6 +43,15 @@ public String getAddress() { return address; } + @Override + public synchronized void closeConnection() { /* Closes the current target, if there is one, and forgets it; does nothing when target is already null. */ + if (target != null) { + logger.log(Level.INFO, "Closing connection to " + address); + target.close(); + target = null; /* NOTE(review): presumably lets the synchronized method below create a fresh target on next use; confirm. */ + } + } + /** * This is synchronized to avoid multiple ConfigInstances creating new targets simultaneously, if * the existing target is null, invalid or has not yet been initialized. 
diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java index c7a4d583ecf2..7a5959c086de 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java @@ -62,11 +62,27 @@ public FileDownloader(ConnectionPool connectionPool, this.timeout = timeout; // Needed to receive RPC receiveFile* calls from server after starting download of file reference new FileReceiver(supervisor, downloads, downloadDirectory); - this.fileReferenceDownloader = new FileReferenceDownloader(connectionPool, - downloads, - timeout, - backoffInitialTime, - downloadDirectory); + this.fileReferenceDownloader = new FileReferenceDownloader(connectionPool, downloads, timeout, + backoffInitialTime, downloadDirectory); + if (forceDownload) + log.log(Level.INFO, "Force download of file references (download even if file reference exists on disk)"); + } + + public FileDownloader(ConnectionPool connectionPool, + Supervisor supervisor, + File downloadDirectory, + Duration timeout, + Duration backoffInitialTime, + int maxTimeoutsBeforeClose) { /* Overload of the constructor above that additionally hands maxTimeoutsBeforeClose to FileReferenceDownloader. */ + this.connectionPool = connectionPool; + this.supervisor = supervisor; + this.downloadDirectory = downloadDirectory; + this.timeout = timeout; + // Needed to receive RPC receiveFile* calls from server after starting download of file reference + new FileReceiver(supervisor, downloads, downloadDirectory); + this.fileReferenceDownloader = new FileReferenceDownloader(connectionPool, downloads, timeout, + backoffInitialTime, downloadDirectory, + maxTimeoutsBeforeClose); /* passed through unchanged; FileReferenceDownloader only acts on it when it is greater than 0 */ if (forceDownload) log.log(Level.INFO, "Force download of file references (download even if file reference exists on disk)"); } diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java 
b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java index 843502da4124..132f5f049d07 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownloader.java @@ -3,6 +3,7 @@ import com.yahoo.concurrent.DaemonThreadFactory; import com.yahoo.config.FileReference; +import com.yahoo.jrt.ErrorCode; import com.yahoo.jrt.Int32Value; import com.yahoo.jrt.Request; import com.yahoo.jrt.Spec; @@ -39,6 +40,8 @@ public class FileReferenceDownloader { private static final Logger log = Logger.getLogger(FileReferenceDownloader.class.getName()); private static final Set defaultAcceptedCompressionTypes = Set.of(lz4, none, zstd); + private enum DownloadResult { SUCCESS, TIMEOUT, FAILURE } /* Outcome of one startDownloadRpc call: SUCCESS when it sees error code 0, TIMEOUT when the request failed with ErrorCode.TIMEOUT, FAILURE for every other failure. */ + private final ExecutorService downloadExecutor = Executors.newFixedThreadPool(Math.max(8, Runtime.getRuntime().availableProcessors()), new DaemonThreadFactory("filereference downloader")); @@ -49,17 +52,31 @@ public class FileReferenceDownloader { private final Optional rpcTimeout; // Only used when overridden with env variable private final File downloadDirectory; private final AtomicBoolean shutDown = new AtomicBoolean(false); + private final int maxTimeoutsBeforeClose; /* Timeout threshold for closing a connection; every use is guarded by a greater-than-zero check, so 0 or a negative value disables the feature. */ FileReferenceDownloader(ConnectionPool connectionPool, Downloads downloads, Duration timeout, Duration backoffInitialTime, File downloadDirectory) { + this(connectionPool, downloads, timeout, backoffInitialTime, downloadDirectory, + Optional.ofNullable(System.getenv("VESPA_FILE_DOWNLOAD_MAX_TIMEOUTS_BEFORE_CLOSE")) + .map(Integer::parseInt) /* NOTE(review): throws NumberFormatException if the variable is set but is not an integer; confirm that failing construction is acceptable. */ + .orElse(0)); /* variable not set: threshold 0, i.e. never close on timeout */ + } + + FileReferenceDownloader(ConnectionPool connectionPool, + Downloads downloads, + Duration timeout, + Duration backoffInitialTime, + File downloadDirectory, + int maxTimeoutsBeforeClose) { this.connectionPool = connectionPool; this.downloads = downloads; this.downloadTimeout = timeout; 
this.backoffInitialTime = backoffInitialTime; this.downloadDirectory = downloadDirectory; + this.maxTimeoutsBeforeClose = maxTimeoutsBeforeClose; // Undocumented on purpose, might change or be removed at any time var timeoutString = Optional.ofNullable(System.getenv("VESPA_FILE_DOWNLOAD_RPC_TIMEOUT")); this.rpcTimeout = timeoutString.map(t -> Duration.ofSeconds(Integer.parseInt(t))); @@ -69,6 +86,7 @@ private void waitUntilDownloadStarted(FileReferenceDownload fileReferenceDownloa Instant end = Instant.now().plus(downloadTimeout); FileReference fileReference = fileReferenceDownload.fileReference(); int retryCount = 0; + int timeoutCount = 0; /* RPC timeouts seen in this call since the last close; only a close resets it, a non-timeout result does not. */ Connection connection = connectionPool.getCurrent(); do { if (retryCount > 0) @@ -81,8 +99,19 @@ private void waitUntilDownloadStarted(FileReferenceDownload fileReferenceDownloa var timeout = rpcTimeout.orElse(Duration.between(Instant.now(), end)); log.log(Level.FINE, "Wait until download of " + fileReference + " has started, retryCount " + retryCount + ", timeout " + timeout + " (request from " + fileReferenceDownload.client() + ")"); - if ( ! timeout.isNegative() && startDownloadRpc(fileReferenceDownload, retryCount, connection, timeout)) - return; + if ( ! 
timeout.isNegative()) { + var result = startDownloadRpc(fileReferenceDownload, retryCount, connection, timeout); + if (result == DownloadResult.SUCCESS) return; + if (result == DownloadResult.TIMEOUT && maxTimeoutsBeforeClose > 0) { + timeoutCount++; + if (timeoutCount >= maxTimeoutsBeforeClose) { + log.log(Level.INFO, "RPC request for " + fileReference + " timed out " + timeoutCount + + " times, closing connection to " + connection.getAddress()); + connection.closeConnection(); /* NOTE(review): timeoutCount is not tied to one connection; if the loop switches connection between retries (not visible in this hunk), earlier timeouts may belong to another connection; confirm. */ + timeoutCount = 0; /* start counting again towards the next close */ + } + } + } retryCount++; // There might not be one connection that works for all file references (each file reference might @@ -131,10 +160,13 @@ void startDownloadFromSource(FileReferenceDownload fileReferenceDownload, Spec s log.log(Level.FINE, () -> "Will download " + fileReference + " with timeout " + downloadTimeout + " from " + spec.host()); downloads.add(fileReferenceDownload); - var downloading = startDownloadRpc(fileReferenceDownload, 1, connection, downloadTimeout); + var result = startDownloadRpc(fileReferenceDownload, 1, connection, downloadTimeout); + if (result == DownloadResult.TIMEOUT && maxTimeoutsBeforeClose > 0) { /* NOTE(review): closes after a single timeout here; the threshold value only acts as an on/off switch in this method; confirm intended. */ + connection.closeConnection(); + } // Need to explicitly remove from downloads if downloading has not started. // If downloading *has* started FileReceiver will take care of that when download has completed or failed - if ( ! 
downloading) + if (result != DownloadResult.SUCCESS) downloads.remove(fileReference); }); } @@ -144,7 +176,7 @@ void failedDownloading(FileReference fileReference) { downloads.remove(fileReference); } - private boolean startDownloadRpc(FileReferenceDownload fileReferenceDownload, int retryCount, Connection connection, Duration timeout) { + private DownloadResult startDownloadRpc(FileReferenceDownload fileReferenceDownload, int retryCount, Connection connection, Duration timeout) { /* Returns SUCCESS, TIMEOUT or FAILURE instead of the previous boolean so callers can tell timeouts apart from other failures. */ Request request = createRequest(fileReferenceDownload); connection.invokeSync(request, timeout); @@ -157,18 +189,18 @@ private boolean startDownloadRpc(FileReferenceDownload fileReferenceDownload, in if (errorCode == 0) { log.log(Level.FINE, () -> "Found " + fileReference + " available at " + address); - return true; + return DownloadResult.SUCCESS; } else { var error = FileApiErrorCodes.get(errorCode); log.log(logLevel, "Downloading " + fileReference + " from " + address + " failed (" + error + ")"); - return false; + return DownloadResult.FAILURE; } } else { log.log(logLevel, "Downloading " + fileReference + " from " + address + " (client " + fileReferenceDownload.client() + ") failed:" + " error code " + request.errorCode() + " (" + request.errorMessage() + ")." + " (retry " + retryCount + ", rpc timeout " + timeout + ")"); - return false; + return request.errorCode() == ErrorCode.TIMEOUT ? 
DownloadResult.TIMEOUT : DownloadResult.FAILURE; /* only ErrorCode.TIMEOUT is reported as TIMEOUT; every other request error is FAILURE */ } } diff --git a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java index 40dacd175e97..619c9ed162c3 100644 --- a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java +++ b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java @@ -32,6 +32,7 @@ import java.util.concurrent.Future; import static com.yahoo.jrt.ErrorCode.CONNECTION; +import static com.yahoo.jrt.ErrorCode.TIMEOUT; import static com.yahoo.vespa.filedistribution.FileReferenceData.CompressionType.zstd; import static com.yahoo.vespa.filedistribution.FileReferenceData.Type; import static com.yahoo.vespa.filedistribution.FileReferenceData.Type.compressed; @@ -264,6 +265,66 @@ public void receiveFile() throws IOException { assertEquals("content", IOUtils.readFile(downloadedFile)); } + @Test + public void testConnectionCloseOnTimeout() { /* Threshold 1: each of the timesToTimeout timed-out requests is expected to close the connection once. */ + int timesToTimeout = 2; + MockConnection mockConnection = new MockConnection(); + MockConnection.TimeoutResponseHandler responseHandler = + new MockConnection.TimeoutResponseHandler(timesToTimeout); + mockConnection.setResponseHandler(responseHandler); + + FileDownloader downloader = new FileDownloader(mockConnection, supervisor, downloadDir, + Duration.ofSeconds(4), sleepBetweenRetries, + 1); /* maxTimeoutsBeforeClose */ + FileReference fileReference = new FileReference("timeoutTest"); + // File won't be found, download will fail after retries and timeout + assertFalse(downloader.getFile(new FileReferenceDownload(fileReference, "test")).isPresent()); + assertEquals("Expected closeConnection called for each timeout, got " + mockConnection.getCloseConnectionCount(), + timesToTimeout, mockConnection.getCloseConnectionCount()); + downloader.close(); + } + + @Test + public void testConnectionCloseAfterNTimeouts() { + int timesToTimeout = 6; + int 
retriesOnTimeoutBeforeClose = 2; /* passed to FileDownloader as maxTimeoutsBeforeClose */ + MockConnection mockConnection = new MockConnection(); + MockConnection.TimeoutResponseHandler responseHandler = + new MockConnection.TimeoutResponseHandler(timesToTimeout); + mockConnection.setResponseHandler(responseHandler); + + FileDownloader downloader = new FileDownloader(mockConnection, supervisor, downloadDir, + Duration.ofSeconds(4), sleepBetweenRetries, + retriesOnTimeoutBeforeClose); + FileReference fileReference = new FileReference("timeoutNTest"); + // File won't be found, download will fail after retries and timeout + assertFalse(downloader.getFile(new FileReferenceDownload(fileReference, "test")).isPresent()); + // With retriesOnTimeoutBeforeClose=2, close happens after the 2nd timeout (count 1,2 -> close), + // then counter resets, so 6 timeouts / 2 = 3 closes + assertEquals("Expected 3 closeConnection calls for 6 timeouts with threshold 2, got " + mockConnection.getCloseConnectionCount(), + 3, mockConnection.getCloseConnectionCount()); + downloader.close(); + } + + @Test + public void testNoConnectionCloseOnTimeoutByDefault() { /* NOTE(review): passes 0 explicitly, so the environment-variable default path of FileReferenceDownloader is not exercised by this test. */ + int timesToTimeout = 2; + MockConnection mockConnection = new MockConnection(); + MockConnection.TimeoutResponseHandler responseHandler = + new MockConnection.TimeoutResponseHandler(timesToTimeout); + mockConnection.setResponseHandler(responseHandler); + + // maxTimeoutsBeforeClose=0 means timeout-based close feature is disabled + FileDownloader downloader = new FileDownloader(mockConnection, supervisor, downloadDir, + Duration.ofSeconds(4), sleepBetweenRetries, + 0); + FileReference fileReference = new FileReference("timeoutDefaultTest"); + assertFalse(downloader.getFile(new FileReferenceDownload(fileReference, "test")).isPresent()); + assertEquals("Expected no closeConnection calls when feature is disabled, got " + mockConnection.getCloseConnectionCount(), + 0, mockConnection.getCloseConnectionCount()); + downloader.close(); + } + private void writeFileReference(File dir, String 
fileReferenceString, String fileName) throws IOException { File fileReferenceDir = new File(dir, fileReferenceString); fileReferenceDir.mkdir(); @@ -307,6 +368,7 @@ private FileDownloader createDownloader(MockConnection connection, Duration time private static class MockConnection implements ConnectionPool, com.yahoo.vespa.config.Connection { private ResponseHandler responseHandler; + private int closeConnectionCount = 0; /* Incremented by closeConnection() and read through getCloseConnectionCount(). NOTE(review): plain int without synchronization; confirm the tests only read it after the download attempt has finished. */ MockConnection() { this(new FileReferenceFoundResponseHandler()); @@ -328,7 +390,16 @@ public void invokeSync(Request request, Duration jrtTimeout) { @Override public String getAddress() { - return null; + return "localhost"; + } + + @Override + public void closeConnection() { + closeConnectionCount++; + } + + int getCloseConnectionCount() { + return closeConnectionCount; } @Override @@ -403,6 +474,30 @@ public void request(Request request) { } } + static class TimeoutResponseHandler implements MockConnection.ResponseHandler { /* Fails the first timesToTimeout filedistribution.serveFile requests with the TIMEOUT error, then answers with return values 0 and "OK"; requests for other methods are left untouched. */ + + private final int timesToTimeout; + private int timedOutTimes = 0; + + TimeoutResponseHandler(int timesToTimeout) { + super(); /* redundant: the class declares no explicit superclass */ + this.timesToTimeout = timesToTimeout; + } + + @Override + public void request(Request request) { + if (request.methodName().equals("filedistribution.serveFile")) { + if (timedOutTimes < timesToTimeout) { + request.setError(TIMEOUT, "Request timed out"); + timedOutTimes++; + } else { + request.returnValues().add(new Int32Value(0)); + request.returnValues().add(new StringValue("OK")); + } + } + } + } + static class ConnectionErrorResponseHandler implements MockConnection.ResponseHandler { private final int timesToFail;