Skip to content

Commit cfdeb8f

Browse files
committed
Fix coalesce to collect index from all nodes and extract ShuffleNodeClient utility
Signed-off-by: Sotaro Hikita <bering1814@gmail.com>
1 parent 2ceca2c commit cfdeb8f

4 files changed

Lines changed: 315 additions & 123 deletions

File tree

data-prepper-plugins/iceberg-source/src/main/java/org/opensearch/dataprepper/plugins/source/iceberg/leader/LeaderScheduler.java

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.opensearch.dataprepper.plugins.source.iceberg.coordination.state.ShuffleReadProgressState;
3333
import org.opensearch.dataprepper.plugins.source.iceberg.coordination.state.ShuffleWriteProgressState;
3434
import org.opensearch.dataprepper.plugins.source.iceberg.shuffle.ShuffleConfig;
35+
import org.opensearch.dataprepper.plugins.source.iceberg.shuffle.ShuffleNodeClient;
3536
import org.opensearch.dataprepper.plugins.source.iceberg.shuffle.ShufflePartitionCoalescer;
3637
import org.opensearch.dataprepper.plugins.source.iceberg.shuffle.ShuffleStorage;
3738
import org.slf4j.Logger;
@@ -396,9 +397,23 @@ private boolean processShuffleSnapshot(final String tableName, final long snapsh
396397
return false;
397398
}
398399

399-
// Barrier: collect index data and coalesce
400+
// Barrier: collect index data from all nodes and coalesce
401+
// Read shuffle write locations from GlobalState first (need node addresses to fetch remote indexes)
402+
final Optional<EnhancedSourcePartition> locationPartition = sourceCoordinator.getPartition(locationKey);
403+
final Map<String, Object> locationMap = locationPartition.map(enhancedSourcePartition -> ((GlobalState) enhancedSourcePartition).getProgressState().orElse(Map.of())).orElseGet(Map::of);
404+
405+
final List<String> completedTaskIds = new ArrayList<>();
406+
final List<String> completedNodeAddresses = new ArrayList<>();
407+
for (final Map.Entry<String, Object> entry : locationMap.entrySet()) {
408+
completedTaskIds.add(entry.getKey());
409+
completedNodeAddresses.add(String.valueOf(entry.getValue()));
410+
}
411+
LOG.info("Collected {} shuffle write locations for snapshot {}", completedTaskIds.size(), snapshotId);
412+
400413
final int numPartitions = shuffleConfig.getPartitions();
401-
final long[] partitionSizes = shuffleStorage.getPartitionSizes(snapshotIdStr, numPartitions);
414+
final ShuffleNodeClient client = new ShuffleNodeClient(shuffleConfig);
415+
final long[] partitionSizes = client.collectPartitionSizes(
416+
shuffleStorage, snapshotIdStr, completedTaskIds, completedNodeAddresses, numPartitions);
402417

403418
final ShufflePartitionCoalescer coalescer =
404419
new ShufflePartitionCoalescer(shuffleConfig.getTargetPartitionSizeBytes());
@@ -410,19 +425,6 @@ private boolean processShuffleSnapshot(final String tableName, final long snapsh
410425
return true;
411426
}
412427

413-
// Read shuffle write locations from GlobalState.
414-
// SHUFFLE_WRITE workers write their location to GlobalState on completion.
415-
final Optional<EnhancedSourcePartition> locationPartition = sourceCoordinator.getPartition(locationKey);
416-
final Map<String, Object> locationMap = locationPartition.map(enhancedSourcePartition -> ((GlobalState) enhancedSourcePartition).getProgressState().orElse(Map.of())).orElseGet(Map::of);
417-
418-
final List<String> completedTaskIds = new ArrayList<>();
419-
final List<String> completedNodeAddresses = new ArrayList<>();
420-
for (final Map.Entry<String, Object> entry : locationMap.entrySet()) {
421-
completedTaskIds.add(entry.getKey());
422-
completedNodeAddresses.add(String.valueOf(entry.getValue()));
423-
}
424-
LOG.info("Collected {} shuffle write locations for snapshot {}", completedTaskIds.size(), snapshotId);
425-
426428
// Phase 2: Create SHUFFLE_READ tasks
427429
final String readCompletionKey = SNAPSHOT_COMPLETION_PREFIX + "sr-" + snapshotId;
428430
sourceCoordinator.createPartition(new GlobalState(readCompletionKey,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* The OpenSearch Contributors require contributions made to
6+
* this file be licensed under the Apache-2.0 license or a
7+
* compatible open source license.
8+
*
9+
*/
10+
11+
package org.opensearch.dataprepper.plugins.source.iceberg.shuffle;
12+
13+
import org.slf4j.Logger;
14+
import org.slf4j.LoggerFactory;
15+
16+
import javax.net.ssl.SSLContext;
17+
import javax.net.ssl.TrustManager;
18+
import javax.net.ssl.X509TrustManager;
19+
import java.net.InetAddress;
20+
import java.net.NetworkInterface;
21+
import java.net.URI;
22+
import java.net.http.HttpClient;
23+
import java.net.http.HttpRequest;
24+
import java.net.http.HttpResponse;
25+
import java.nio.ByteBuffer;
26+
import java.security.cert.X509Certificate;
27+
import java.time.Duration;
28+
import java.util.List;
29+
30+
/**
31+
* Shared HTTP client utilities for pulling shuffle index and data from remote nodes.
32+
* Used by both LeaderScheduler (index collection for coalesce) and ChangelogWorker (data pull for SHUFFLE_READ).
33+
*/
34+
public class ShuffleNodeClient {
35+
36+
private static final Logger LOG = LoggerFactory.getLogger(ShuffleNodeClient.class);
37+
private static final int MAX_RETRIES = 3;
38+
39+
private final HttpClient httpClient;
40+
private final String scheme;
41+
private final int port;
42+
43+
public ShuffleNodeClient(final ShuffleConfig config) {
44+
this.httpClient = buildHttpClient(config);
45+
this.scheme = config.isSsl() ? "https" : "http";
46+
this.port = config.getServerPort();
47+
}
48+
49+
public long[] pullIndex(final String nodeAddress, final String snapshotId, final String taskId) throws Exception {
50+
final byte[] body = executeWithRetry(
51+
String.format("%s://%s:%d/shuffle/%s/%s/index", scheme, nodeAddress, port, snapshotId, taskId),
52+
Duration.ofSeconds(10),
53+
"index from " + nodeAddress);
54+
final ByteBuffer buf = ByteBuffer.wrap(body);
55+
final long[] offsets = new long[body.length / Long.BYTES];
56+
for (int i = 0; i < offsets.length; i++) {
57+
offsets[i] = buf.getLong();
58+
}
59+
return offsets;
60+
}
61+
62+
public byte[] pullData(final String nodeAddress, final String snapshotId, final String taskId,
63+
final long offset, final int length) throws Exception {
64+
return executeWithRetry(
65+
String.format("%s://%s:%d/shuffle/%s/%s/data?offset=%d&length=%d",
66+
scheme, nodeAddress, port, snapshotId, taskId, offset, length),
67+
Duration.ofSeconds(30),
68+
"data from " + nodeAddress);
69+
}
70+
71+
private byte[] executeWithRetry(final String url, final Duration timeout, final String description) throws Exception {
72+
for (int attempt = 1; attempt <= MAX_RETRIES; attempt++) {
73+
try {
74+
final HttpResponse<byte[]> response = httpClient.send(
75+
HttpRequest.newBuilder(URI.create(url)).GET().timeout(timeout).build(),
76+
HttpResponse.BodyHandlers.ofByteArray());
77+
if (response.statusCode() == 200) {
78+
return response.body();
79+
}
80+
LOG.warn("HTTP pull failed for {}: status={} attempt={}/{}", description, response.statusCode(), attempt, MAX_RETRIES);
81+
} catch (final Exception e) {
82+
LOG.warn("HTTP pull failed for {}: attempt={}/{}", description, attempt, MAX_RETRIES, e);
83+
}
84+
if (attempt < MAX_RETRIES) {
85+
Thread.sleep(1000L * attempt);
86+
}
87+
}
88+
throw new RuntimeException("Failed to pull " + description + " after " + MAX_RETRIES + " retries");
89+
}
90+
91+
public static boolean isLocalAddress(final String address) {
92+
try {
93+
final InetAddress inetAddress = InetAddress.getByName(address);
94+
if (inetAddress.isAnyLocalAddress() || inetAddress.isLoopbackAddress()) {
95+
return true;
96+
}
97+
return NetworkInterface.getByInetAddress(inetAddress) != null;
98+
} catch (final Exception e) {
99+
return false;
100+
}
101+
}
102+
103+
public static String resolveLocalAddress() {
104+
try {
105+
return InetAddress.getLocalHost().getHostName();
106+
} catch (final Exception e) {
107+
throw new RuntimeException("Failed to resolve local host name", e);
108+
}
109+
}
110+
111+
/**
112+
* Collects partition sizes from all nodes by reading index files.
113+
* Local tasks are read from disk, remote tasks are fetched via HTTP.
114+
*/
115+
public long[] collectPartitionSizes(final ShuffleStorage shuffleStorage,
116+
final String snapshotId,
117+
final List<String> taskIds,
118+
final List<String> nodeAddresses,
119+
final int numPartitions) {
120+
final long[] sizes = new long[numPartitions];
121+
for (int i = 0; i < taskIds.size(); i++) {
122+
final String taskId = taskIds.get(i);
123+
final String nodeAddress = nodeAddresses.get(i);
124+
try {
125+
final long[] offsets;
126+
if (isLocalAddress(nodeAddress)) {
127+
try (var reader = shuffleStorage.createReader(snapshotId, taskId)) {
128+
offsets = reader.readIndex();
129+
}
130+
} else {
131+
offsets = pullIndex(nodeAddress, snapshotId, taskId);
132+
}
133+
for (int p = 0; p < numPartitions && p + 1 < offsets.length; p++) {
134+
sizes[p] += offsets[p + 1] - offsets[p];
135+
}
136+
} catch (final Exception e) {
137+
LOG.warn("Failed to read index for task {} from node {}, skipping for coalesce", taskId, nodeAddress, e);
138+
}
139+
}
140+
return sizes;
141+
}
142+
143+
private static HttpClient buildHttpClient(final ShuffleConfig config) {
144+
final HttpClient.Builder builder = HttpClient.newBuilder()
145+
.connectTimeout(Duration.ofSeconds(10));
146+
if (config.isSsl() && config.isSslInsecureDisableVerification()) {
147+
try {
148+
final TrustManager[] trustAllCerts = new TrustManager[]{
149+
new X509TrustManager() {
150+
public X509Certificate[] getAcceptedIssuers() { return new X509Certificate[0]; }
151+
public void checkClientTrusted(X509Certificate[] certs, String authType) {}
152+
public void checkServerTrusted(X509Certificate[] certs, String authType) {}
153+
}
154+
};
155+
final SSLContext sslContext = SSLContext.getInstance("TLS");
156+
sslContext.init(null, trustAllCerts, new java.security.SecureRandom());
157+
builder.sslContext(sslContext);
158+
} catch (final Exception e) {
159+
throw new RuntimeException("Failed to configure insecure SSL context", e);
160+
}
161+
}
162+
return builder.build();
163+
}
164+
}

0 commit comments

Comments
 (0)