Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions config/src/main/java/com/yahoo/vespa/config/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ public interface Connection {

String getAddress();

default void closeConnection() {}

}
15 changes: 15 additions & 0 deletions config/src/main/java/com/yahoo/vespa/config/JRTConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@ public String getAddress() {
return address;
}

public synchronized void closeTarget() {
if (target != null) {
logger.log(Level.INFO, "Force-closing connection to " + address +
" (target valid=" + target.isValid() + ")");
target.closeSocket(); // Synchronously close TCP socket before async cleanup
target.close();
target = null;
}
}

@Override
public void closeConnection() {
closeTarget();
}

/**
* This is synchronized to avoid multiple ConfigInstances creating new targets simultaneously, if
* the existing target is null, invalid or has not yet been initialized.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,27 @@ public FileDownloader(ConnectionPool connectionPool,
this.timeout = timeout;
// Needed to receive RPC receiveFile* calls from server after starting download of file reference
new FileReceiver(supervisor, downloads, downloadDirectory);
this.fileReferenceDownloader = new FileReferenceDownloader(connectionPool,
downloads,
timeout,
backoffInitialTime,
downloadDirectory);
this.fileReferenceDownloader = new FileReferenceDownloader(connectionPool, downloads, timeout,
backoffInitialTime, downloadDirectory);
if (forceDownload)
log.log(Level.INFO, "Force download of file references (download even if file reference exists on disk)");
}

public FileDownloader(ConnectionPool connectionPool,
Supervisor supervisor,
File downloadDirectory,
Duration timeout,
Duration backoffInitialTime,
int maxTimeoutsBeforeClose) {
this.connectionPool = connectionPool;
this.supervisor = supervisor;
this.downloadDirectory = downloadDirectory;
this.timeout = timeout;
// Needed to receive RPC receiveFile* calls from server after starting download of file reference
new FileReceiver(supervisor, downloads, downloadDirectory);
this.fileReferenceDownloader = new FileReferenceDownloader(connectionPool, downloads, timeout,
backoffInitialTime, downloadDirectory,
maxTimeoutsBeforeClose);
if (forceDownload)
log.log(Level.INFO, "Force download of file references (download even if file reference exists on disk)");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import com.yahoo.concurrent.DaemonThreadFactory;
import com.yahoo.config.FileReference;
import com.yahoo.jrt.ErrorCode;
import com.yahoo.jrt.Int32Value;
import com.yahoo.jrt.Request;
import com.yahoo.jrt.Spec;
Expand Down Expand Up @@ -39,6 +40,8 @@ public class FileReferenceDownloader {
private static final Logger log = Logger.getLogger(FileReferenceDownloader.class.getName());
private static final Set<CompressionType> defaultAcceptedCompressionTypes = Set.of(lz4, none, zstd);

private enum DownloadResult { SUCCESS, TIMEOUT, FAILURE }

private final ExecutorService downloadExecutor =
Executors.newFixedThreadPool(Math.max(8, Runtime.getRuntime().availableProcessors()),
new DaemonThreadFactory("filereference downloader"));
Expand All @@ -49,17 +52,31 @@ public class FileReferenceDownloader {
private final Optional<Duration> rpcTimeout; // Only used when overridden with env variable
private final File downloadDirectory;
private final AtomicBoolean shutDown = new AtomicBoolean(false);
private final int maxTimeoutsBeforeClose;

FileReferenceDownloader(ConnectionPool connectionPool,
Downloads downloads,
Duration timeout,
Duration backoffInitialTime,
File downloadDirectory) {
this(connectionPool, downloads, timeout, backoffInitialTime, downloadDirectory,
Optional.ofNullable(System.getenv("VESPA_FILE_DOWNLOAD_MAX_TIMEOUTS_BEFORE_CLOSE"))
.map(Integer::parseInt)
.orElse(0));
}

FileReferenceDownloader(ConnectionPool connectionPool,
Downloads downloads,
Duration timeout,
Duration backoffInitialTime,
File downloadDirectory,
int maxTimeoutsBeforeClose) {
this.connectionPool = connectionPool;
this.downloads = downloads;
this.downloadTimeout = timeout;
this.backoffInitialTime = backoffInitialTime;
this.downloadDirectory = downloadDirectory;
this.maxTimeoutsBeforeClose = maxTimeoutsBeforeClose;
// Undocumented on purpose, might change or be removed at any time
var timeoutString = Optional.ofNullable(System.getenv("VESPA_FILE_DOWNLOAD_RPC_TIMEOUT"));
this.rpcTimeout = timeoutString.map(t -> Duration.ofSeconds(Integer.parseInt(t)));
Expand All @@ -69,6 +86,7 @@ private void waitUntilDownloadStarted(FileReferenceDownload fileReferenceDownloa
Instant end = Instant.now().plus(downloadTimeout);
FileReference fileReference = fileReferenceDownload.fileReference();
int retryCount = 0;
int timeoutCount = 0;
Connection connection = connectionPool.getCurrent();
do {
if (retryCount > 0)
Expand All @@ -81,8 +99,19 @@ private void waitUntilDownloadStarted(FileReferenceDownload fileReferenceDownloa
var timeout = rpcTimeout.orElse(Duration.between(Instant.now(), end));
log.log(Level.FINE, "Wait until download of " + fileReference + " has started, retryCount " + retryCount +
", timeout " + timeout + " (request from " + fileReferenceDownload.client() + ")");
if ( ! timeout.isNegative() && startDownloadRpc(fileReferenceDownload, retryCount, connection, timeout))
return;
if ( ! timeout.isNegative()) {
var result = startDownloadRpc(fileReferenceDownload, retryCount, connection, timeout);
if (result == DownloadResult.SUCCESS) return;
if (result == DownloadResult.TIMEOUT && maxTimeoutsBeforeClose > 0) {
timeoutCount++;
if (timeoutCount >= maxTimeoutsBeforeClose) {
log.log(Level.INFO, "RPC request for " + fileReference + " timed out " + timeoutCount +
" times, force-closing connection to " + connection.getAddress());
connection.closeConnection();
timeoutCount = 0;
}
}
}

retryCount++;
// There might not be one connection that works for all file references (each file reference might
Expand Down Expand Up @@ -131,10 +160,13 @@ void startDownloadFromSource(FileReferenceDownload fileReferenceDownload, Spec s

log.log(Level.FINE, () -> "Will download " + fileReference + " with timeout " + downloadTimeout + " from " + spec.host());
downloads.add(fileReferenceDownload);
var downloading = startDownloadRpc(fileReferenceDownload, 1, connection, downloadTimeout);
var result = startDownloadRpc(fileReferenceDownload, 1, connection, downloadTimeout);
if (result == DownloadResult.TIMEOUT && maxTimeoutsBeforeClose > 0) {
connection.closeConnection();
}
// Need to explicitly remove from downloads if downloading has not started.
// If downloading *has* started FileReceiver will take care of that when download has completed or failed
if ( ! downloading)
if (result != DownloadResult.SUCCESS)
downloads.remove(fileReference);
});
}
Expand All @@ -144,7 +176,7 @@ void failedDownloading(FileReference fileReference) {
downloads.remove(fileReference);
}

private boolean startDownloadRpc(FileReferenceDownload fileReferenceDownload, int retryCount, Connection connection, Duration timeout) {
private DownloadResult startDownloadRpc(FileReferenceDownload fileReferenceDownload, int retryCount, Connection connection, Duration timeout) {
Request request = createRequest(fileReferenceDownload);
connection.invokeSync(request, timeout);

Expand All @@ -157,18 +189,18 @@ private boolean startDownloadRpc(FileReferenceDownload fileReferenceDownload, in

if (errorCode == 0) {
log.log(Level.FINE, () -> "Found " + fileReference + " available at " + address);
return true;
return DownloadResult.SUCCESS;
} else {
var error = FileApiErrorCodes.get(errorCode);
log.log(logLevel, "Downloading " + fileReference + " from " + address + " failed (" + error + ")");
return false;
return DownloadResult.FAILURE;
}
} else {
log.log(logLevel, "Downloading " + fileReference + " from " + address +
" (client " + fileReferenceDownload.client() + ") failed:" +
" error code " + request.errorCode() + " (" + request.errorMessage() + ")." +
" (retry " + retryCount + ", rpc timeout " + timeout + ")");
return false;
return request.errorCode() == ErrorCode.TIMEOUT ? DownloadResult.TIMEOUT : DownloadResult.FAILURE;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.concurrent.Future;

import static com.yahoo.jrt.ErrorCode.CONNECTION;
import static com.yahoo.jrt.ErrorCode.TIMEOUT;
import static com.yahoo.vespa.filedistribution.FileReferenceData.CompressionType.zstd;
import static com.yahoo.vespa.filedistribution.FileReferenceData.Type;
import static com.yahoo.vespa.filedistribution.FileReferenceData.Type.compressed;
Expand Down Expand Up @@ -264,6 +265,66 @@ public void receiveFile() throws IOException {
assertEquals("content", IOUtils.readFile(downloadedFile));
}

@Test
public void testConnectionCloseOnTimeout() {
int timesToTimeout = 2;
MockConnection mockConnection = new MockConnection();
MockConnection.TimeoutResponseHandler responseHandler =
new MockConnection.TimeoutResponseHandler(timesToTimeout);
mockConnection.setResponseHandler(responseHandler);

FileDownloader downloader = new FileDownloader(mockConnection, supervisor, downloadDir,
Duration.ofSeconds(4), sleepBetweenRetries,
1);
FileReference fileReference = new FileReference("timeoutTest");
// File won't be found, download will fail after retries and timeout
assertFalse(downloader.getFile(new FileReferenceDownload(fileReference, "test")).isPresent());
assertEquals("Expected closeConnection called for each timeout, got " + mockConnection.getCloseConnectionCount(),
timesToTimeout, mockConnection.getCloseConnectionCount());
downloader.close();
}

@Test
public void testConnectionCloseAfterNTimeouts() {
int timesToTimeout = 6;
int retriesOnTimeoutBeforeClose = 2;
MockConnection mockConnection = new MockConnection();
MockConnection.TimeoutResponseHandler responseHandler =
new MockConnection.TimeoutResponseHandler(timesToTimeout);
mockConnection.setResponseHandler(responseHandler);

FileDownloader downloader = new FileDownloader(mockConnection, supervisor, downloadDir,
Duration.ofSeconds(4), sleepBetweenRetries,
retriesOnTimeoutBeforeClose);
FileReference fileReference = new FileReference("timeoutNTest");
// File won't be found, download will fail after retries and timeout
assertFalse(downloader.getFile(new FileReferenceDownload(fileReference, "test")).isPresent());
// With retriesOnTimeoutBeforeClose=2, close happens after the 2nd timeout (count 1,2 -> close),
// then counter resets, so 6 timeouts / 2 = 3 closes
assertEquals("Expected 3 closeConnection calls for 6 timeouts with threshold 2, got " + mockConnection.getCloseConnectionCount(),
3, mockConnection.getCloseConnectionCount());
downloader.close();
}

@Test
public void testNoConnectionCloseOnTimeoutByDefault() {
int timesToTimeout = 2;
MockConnection mockConnection = new MockConnection();
MockConnection.TimeoutResponseHandler responseHandler =
new MockConnection.TimeoutResponseHandler(timesToTimeout);
mockConnection.setResponseHandler(responseHandler);

// maxTimeoutsBeforeClose=0 means timeout-based close feature is disabled
FileDownloader downloader = new FileDownloader(mockConnection, supervisor, downloadDir,
Duration.ofSeconds(4), sleepBetweenRetries,
0);
FileReference fileReference = new FileReference("timeoutDefaultTest");
assertFalse(downloader.getFile(new FileReferenceDownload(fileReference, "test")).isPresent());
assertEquals("Expected no closeConnection calls when feature is disabled, got " + mockConnection.getCloseConnectionCount(),
0, mockConnection.getCloseConnectionCount());
downloader.close();
}

private void writeFileReference(File dir, String fileReferenceString, String fileName) throws IOException {
File fileReferenceDir = new File(dir, fileReferenceString);
fileReferenceDir.mkdir();
Expand Down Expand Up @@ -307,6 +368,7 @@ private FileDownloader createDownloader(MockConnection connection, Duration time
private static class MockConnection implements ConnectionPool, com.yahoo.vespa.config.Connection {

private ResponseHandler responseHandler;
private int closeConnectionCount = 0;

MockConnection() {
this(new FileReferenceFoundResponseHandler());
Expand All @@ -328,7 +390,16 @@ public void invokeSync(Request request, Duration jrtTimeout) {

@Override
public String getAddress() {
return null;
return "localhost";
}

@Override
public void closeConnection() {
closeConnectionCount++;
}

int getCloseConnectionCount() {
return closeConnectionCount;
}

@Override
Expand Down Expand Up @@ -403,6 +474,30 @@ public void request(Request request) {
}
}

static class TimeoutResponseHandler implements MockConnection.ResponseHandler {

private final int timesToTimeout;
private int timedOutTimes = 0;

TimeoutResponseHandler(int timesToTimeout) {
super();
this.timesToTimeout = timesToTimeout;
}

@Override
public void request(Request request) {
if (request.methodName().equals("filedistribution.serveFile")) {
if (timedOutTimes < timesToTimeout) {
request.setError(TIMEOUT, "Request timed out");
timedOutTimes++;
} else {
request.returnValues().add(new Int32Value(0));
request.returnValues().add(new StringValue("OK"));
}
}
}
}

static class ConnectionErrorResponseHandler implements MockConnection.ResponseHandler {

private final int timesToFail;
Expand Down
33 changes: 28 additions & 5 deletions jrt/src/com/yahoo/jrt/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class Connection extends Target {
private final Supervisor owner;
private final Spec spec;
private CryptoSocket socket;
private volatile SocketChannel channelForClose;
private int readSize = READ_SIZE;
private final boolean server;
private final AtomicLong requestId = new AtomicLong(0);
Expand Down Expand Up @@ -90,6 +91,7 @@ public Connection(TransportThread parent, Supervisor owner,
this.parent = parent;
this.owner = owner;
this.socket = parent.transport().createServerCryptoSocket(channel);
this.channelForClose = socket.channel();
this.spec = null;
this.tcpNoDelay = tcpNoDelay;
maxInputSize = owner.getMaxInputBufferSize();
Expand Down Expand Up @@ -165,6 +167,7 @@ public Connection connect() {
}
try {
socket = parent.transport().createClientCryptoSocket(SocketChannel.open(spec.resolveAddress()), spec);
channelForClose = socket.channel(); // volatile write for cross-thread visibility
} catch (Exception e) {
setLostReason(e);
}
Expand All @@ -178,6 +181,8 @@ public boolean init(Selector selector) {
try {
socket.channel().configureBlocking(false);
socket.channel().socket().setTcpNoDelay(tcpNoDelay);
socket.channel().socket().setKeepAlive(true);
socket.channel().socket().setSoLinger(true, 0);
selectionKey = socket.channel().register(selector,
SelectionKey.OP_READ | SelectionKey.OP_WRITE,
this);
Expand Down Expand Up @@ -395,12 +400,30 @@ public boolean hasSocket() {
return ((socket != null) && (socket.channel() != null));
}

@Override
public void closeSocket() {
if (hasSocket()) {
try {
socket.channel().socket().close();
} catch (Exception e) {
log.log(Level.WARNING, "Error closing connection", e);
SocketChannel ch = channelForClose; // volatile read — guaranteed cross-thread visibility
if (ch != null) {
if (ch.isOpen()) {
String socketInfo = "";
try {
socketInfo = "local=" + ch.socket().getLocalPort() +
" remote=" + ch.socket().getRemoteSocketAddress();
} catch (Exception ignored) {}
try {
log.log(Level.INFO, "Closing socket channel: " + socketInfo);
ch.close();
} catch (Exception e) {
log.log(Level.WARNING, "Error closing socket channel: " + socketInfo, e);
}
}
} else {
if (hasSocket()) {
try {
socket.channel().close();
} catch (Exception e) {
log.log(Level.WARNING, "Error closing connection", e);
}
}
}
}
Expand Down
Loading