Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ public class Client implements AutoCloseable {
private static final ThreadLocal<Integer> retryCount = new ThreadLocal<Integer>();
private static final ThreadLocal<Object> EXTERNAL_CALL_HANDLER
= new ThreadLocal<>();
public static final ThreadLocal<CompletableFuture<Object>> CALL_FUTURE_THREAD_LOCAL
= new ThreadLocal<>();
private static final ThreadLocal<AsyncGet<? extends Writable, IOException>>
ASYNC_RPC_RESPONSE = new ThreadLocal<>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is already a field ASYNC_RPC_RESPONSE. Please replace it with CompletableFuture instead of adding a new field.

Copy link
Member Author

@KeeProMise KeeProMise Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@szetszwo Hi, thanks for you review! Do you mean that I need to delete this Client.ASYNC_RPC_RESPONSE, and then use CompletableFuture in all places where ASYNC_RPC_RESPONSE is used? Such modification may require repairing many unit tests, because ASYNC_RPC_RESPONSE is used by many unit tests, and ProtobufRpcEngine2 and ProtobufRpcEngine also use ASYNC_RPC_RESPONSE. If I delete ASYNC_RPC_RESPONSE, then how should the ASYNC_RETURN_MESSAGE attribute of ProtobufRpcEngine2 and ProtobufRpcEngine be processed? Is it also needed to removed and use CompletableFuture? which may affect more unit tests and code.

private static final ThreadLocal<Boolean> asynchronousMode =
Expand Down Expand Up @@ -283,6 +285,7 @@ static class Call {
boolean done; // true when call is done
private final Object externalHandler;
private AlignmentContext alignmentContext;
private final CompletableFuture<Object> completableFuture;

private Call(RPC.RpcKind rpcKind, Writable param) {
this.rpcKind = rpcKind;
Expand All @@ -304,6 +307,8 @@ private Call(RPC.RpcKind rpcKind, Writable param) {
}

this.externalHandler = EXTERNAL_CALL_HANDLER.get();
this.completableFuture = CALL_FUTURE_THREAD_LOCAL.get();
CALL_FUTURE_THREAD_LOCAL.remove();
}

@Override
Expand All @@ -322,6 +327,9 @@ protected synchronized void callComplete() {
externalHandler.notify();
}
}
if (completableFuture != null) {
completableFuture.complete(this);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time;
import org.apache.hadoop.util.concurrent.AsyncGetFuture;
import org.junit.Assert;
import org.junit.Before;
Expand All @@ -38,13 +39,17 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

public class TestAsyncIPC {

Expand Down Expand Up @@ -137,6 +142,77 @@ void assertReturnValues(long timeout, TimeUnit unit)
}
}

/**
* For testing the asynchronous calls of the RPC client
* implemented with CompletableFuture.
*/
static class AsyncCompletableFutureCaller extends Thread {
private final Client client;
private final InetSocketAddress server;
private final int count;
private final List<CompletableFuture<Object>> completableFutures = new ArrayList<>();
private final List<Long> expectedValues = new ArrayList<>();

AsyncCompletableFutureCaller(Client client, InetSocketAddress server, int count) {
this.client = client;
this.server = server;
this.count = count;
setName("Async CompletableFuture Caller");
}

@Override
public void run() {
// Set the RPC client to use asynchronous mode.
Client.setAsynchronousMode(true);
long startTime = Time.monotonicNow();
try {
for (int i = 0; i < count; i++) {
final long param = TestIPC.RANDOM.nextLong();
// Set the CompletableFuture object for the current Client.Call.
CompletableFuture<Object> completableFuture = new CompletableFuture<>();
Client.CALL_FUTURE_THREAD_LOCAL.set(completableFuture);
// Execute asynchronous call.
TestIPC.call(client, param, server, conf);
expectedValues.add(param);
// After the call is completed, the response thread
// (currently the Client.connection thread) retrieves the response
// using the AsyncGetFuture<Writable, IOException> object.
AsyncGetFuture<Writable, IOException> asyncRpcResponse = getAsyncRpcResponseFuture();
completableFuture = completableFuture.thenApply(call -> {
LOG.info("[{}] Async response for {}", Thread.currentThread().getName(), call);
assertTrue(Thread.currentThread().getName().contains("connection"));
try {
// Since the current call has already been completed,
// this method does not need to block.
return asyncRpcResponse.get();
} catch (Exception e) {
throw new CompletionException(e);
}
});
completableFutures.add(completableFuture);
}
// Since the run method is asynchronous,
// it does not need to wait for a response after sending a request,
// so the time taken by the run method is less than count * 100
// (where 100 is the time taken by the server to process a request).
long cost = Time.monotonicNow() - startTime;
assertTrue(cost < count * 100);
LOG.info("[{}] run cost {}ms", Thread.currentThread().getName(), cost);
} catch (Exception e) {
fail();
}
}

public void assertReturnValues()
throws InterruptedException, ExecutionException {
for (int i = 0; i < count; i++) {
LongWritable value = (LongWritable) completableFutures.get(i).get();
Assert.assertEquals("call" + i + " failed.",
expectedValues.get(i).longValue(), value.get());
}
}
}

static class AsyncLimitlCaller extends Thread {
private Client client;
private InetSocketAddress server;
Expand Down Expand Up @@ -538,4 +614,36 @@ public void run() {
assertEquals(startID + i, callIds.get(i).intValue());
}
}

@Test(timeout = 60000)
public void testAsyncCallWithCompletableFuture() throws IOException,
InterruptedException, ExecutionException {
// Override client to store the call id
final Client client = new Client(LongWritable.class, conf);

// Construct an RPC server, which includes a handler thread.
final TestServer server = new TestIPC.TestServer(1, false, conf);
server.callListener = () -> {
try {
// The server requires at least 100 milliseconds to process a request.
Thread.sleep(100);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
};

try {
InetSocketAddress addr = NetUtils.getConnectAddress(server);
server.start();
// Send 10 asynchronous requests.
final AsyncCompletableFutureCaller caller =
new AsyncCompletableFutureCaller(client, addr, 10);
caller.run();
// Check if the values returned by the asynchronous call meet the expected values.
caller.assertReturnValues();
} finally {
client.stop();
server.stop();
}
}
}