-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Java Native Remote Inference #36623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Java Native Remote Inference #36623
Conversation
|
Hi @jrmccluskey, would like to get your review. Thanks. |
jrmccluskey
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the bones here, although I think there needs to be a bit of work around the inputs being sequences of input types + having the outputs be (InputT, OutputT) tuples. Makes it easier to start from that point than to try and retrofit things once batching exists.
...nference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAiModelHandler.java
Outdated
Show resolved
Hide resolved
| return builder().setParameters(modelParameters).build(); | ||
| } | ||
|
|
||
| @Override |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From the Python implementation (as well as the cross-language implementation in Java) we're generally trying to return input-output pairs to make it easier to process the results downstream.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, updated output format.
...l/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/base/BatchConfig.java
Outdated
Show resolved
Hide resolved
|
Checks are failing. Will not request review until checks are succeeding. If you'd like to override that behavior, comment |
|
Assigning reviewers: R: @kennknowles for label java. Note: If you would like to opt out of this review, comment Available commands:
The PR bot will only process comments in the main thread (not review comments). |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #36623 +/- ##
=========================================
Coverage 36.22% 36.22%
Complexity 1666 1666
=========================================
Files 1058 1058
Lines 165088 165088
Branches 1190 1190
=========================================
Hits 59796 59796
Misses 103116 103116
Partials 2176 2176 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
jrmccluskey
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From a pure functionality standpoint I think this is in a good place. I'd like to see at least a naive retry wrapper around requests before merging this in, and I'm going to use the Gemini reviewer + some Beam committers with better knowledge of Java style to make sure this looks good
| // Core Beam SDK | ||
| implementation(project(":sdks:java:core")) | ||
|
|
||
| implementation("com.openai:openai-java:4.3.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it make more sense to scope the Open AI dependency for the Open AI model handler specifically? Otherwise any usage of a remote model handler, regardless of target service, would have a bunch of extra dependencies.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Totally, I thought about as well, model handlers can be their own separate modules.
| @ProcessElement | ||
| public void processElement(ProcessContext c) { | ||
| Iterable<PredictionResult<InputT, OutputT>> response = this.handler.request(c.element()); | ||
| c.output(response); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The next step here is to add some sort of try-catch logic around the request call. I'd recommend looking through the RequestResponseIO code (https://github.com/apache/beam/tree/ce1b1dcbc596d1e7c914ee0f7b0d48f2d2bf87e1/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse) to get an idea of what that looks like.
|
@gemini-code-assist review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a new RemoteInference transform for the Java SDK, providing a framework for making calls to external ML services, along with an implementation for OpenAI. The overall structure is well-designed, with a clear separation of concerns between the generic framework and the model-specific handlers. The tests are also comprehensive, covering unit, integration, and failure scenarios.
My review includes several suggestions to improve performance, correctness, and type safety. Key points include:
- Fixing a critical dependency issue in
build.gradle.kts. - Implementing batching to improve performance by reducing remote calls.
- Reusing the
ObjectMapperinstance inOpenAIModelHandlerfor better performance. - Improving dependency scoping for logging libraries.
- Correcting an invalid model name in an integration test.
Addressing these points will make the new feature more robust and performant. Great work on this new capability!
| compileOnly("com.google.auto.value:auto-value-annotations:1.11.0") | ||
| compileOnly("org.checkerframework:checker-qual:3.42.0") | ||
| annotationProcessor("com.google.auto.value:auto-value:1.11.0") | ||
| implementation("com.fasterxml.jackson.core:jackson-core:2.20.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| implementation("com.fasterxml.jackson.core:jackson-core:2.20.0") | ||
| implementation("org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1") | ||
| implementation("org.slf4j:slf4j-api:2.0.9") | ||
| implementation("org.slf4j:slf4j-simple:2.0.9") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The slf4j-simple dependency should be scoped to testImplementation instead of implementation. Including a concrete logging implementation like slf4j-simple in the implementation scope can lead to classpath conflicts in downstream projects that use this module, as they may have their own preferred logging framework. It's a best practice for libraries to only depend on slf4j-api.
testImplementation("org.slf4j:slf4j-simple:2.0.9")| return input | ||
| .apply("WrapInputInList", MapElements.via(new SimpleFunction<InputT, List<InputT>>() { | ||
| @Override | ||
| public List<InputT> apply(InputT element) { | ||
| return Collections.singletonList(element); | ||
| } | ||
| })) | ||
| // Pass the list to the inference function | ||
| .apply("RemoteInference", ParDo.of(new RemoteInferenceFn<InputT, OutputT>(this))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation processes each input element individually by wrapping it in a singleton list. This results in a separate remote inference call for every element, which is highly inefficient and will lead to poor performance, especially for large datasets. The BaseModelHandler#request method already accepts a List<InputT>, suggesting that batching is intended.
To improve performance, you should introduce batching before the ParDo transform. You can use org.apache.beam.sdk.transforms.GroupIntoBatches to group elements into batches of a configurable size. This will significantly reduce the number of remote calls.
For example, you could add a batchSize parameter to the Invoke transform and then use it like this:
// First, add a key to each element
input.apply(WithKeys.of("key"))
// Then, group elements into batches
.apply(GroupIntoBatches.ofSize(batchSize))
// Then, get the values (the batches)
.apply(Values.create())
// Finally, perform the remote inference on each batch
.apply("RemoteInference", ParDo.of(new RemoteInferenceFn<InputT, OutputT>(this)));| String inputBatch = new ObjectMapper() | ||
| .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A new ObjectMapper instance is created for every call to the request method. ObjectMapper is a heavy object to create, and it is thread-safe. Creating it repeatedly in a hot path like this can negatively impact performance.
You should create the ObjectMapper instance once and reuse it. A good place to initialize it would be in the createClient method and store it in a transient field.
- Add a field to
OpenAIModelHandler:
private transient ObjectMapper objectMapper;- Initialize it in
createClient:
@Override
public void createClient(OpenAIModelParameters parameters) {
this.modelParameters = parameters;
this.client = OpenAIOkHttpClient.builder()
.apiKey(this.modelParameters.getApiKey())
.build();
this.objectMapper = new ObjectMapper();
}- Use the field in
request:
| String inputBatch = new ObjectMapper() | |
| .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList()); | |
| String inputBatch = objectMapper | |
| .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList()); |
| .handler(OpenAIModelHandler.class) | ||
| .withParameters(OpenAIModelParameters.builder() | ||
| .apiKey(apiKey) | ||
| .modelName("gpt-5") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The integration test testWithDifferentModel uses the model name "gpt-5". This model is not currently available, and using it will cause the integration test to fail with an "invalid model" error from the OpenAI API. Please use a valid and available model name to ensure the test can run successfully. For example, you could use another real model like gpt-4-turbo.
| .modelName("gpt-5") | |
| .modelName("gpt-4-turbo") |
| * }</pre> | ||
| * | ||
| */ | ||
| @SuppressWarnings({ "rawtypes", "unchecked" }) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The class is suppressed with @SuppressWarnings({ "rawtypes", "unchecked" }) because BaseModelHandler is used as a raw type. This reduces type safety and can hide potential class cast exceptions at runtime.
For example, Invoke.handler() returns Class<? extends BaseModelHandler>, and RemoteInferenceFn.handler is a raw BaseModelHandler. This means the compiler cannot verify that the BaseModelParameters and BaseInput/BaseResponse types are compatible between the Invoke transform and the BaseModelHandler implementation.
While fixing this might require some significant refactoring of the generics, it would make the framework more robust. A potential direction could be to include the BaseModelParameters type in the Invoke transform's generics, like Invoke<InputT, OutputT, ParamT extends BaseModelParameters>. This would allow for stronger type checking throughout the implementation.
| implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> { | ||
|
|
||
| private transient OpenAIClient client; | ||
| private transient StructuredResponseCreateParams<StructuredInputOutput> clientParams; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The field clientParams is declared as a transient instance variable but it is only assigned and used within the request method. This makes the code harder to reason about, as it suggests clientParams holds state across method calls, which it doesn't. It should be a local variable within the request method.
| this.clientParams = ResponseCreateParams.builder() | ||
| .model(modelParameters.getModelName()) | ||
| .input(inputBatch) | ||
| .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) | ||
| .instructions(modelParameters.getInstructionPrompt()) | ||
| .build(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The clientParams is being assigned to an instance field, but it's only used locally within the request method. It should be a local variable. This also requires removing the clientParams field from the class.
| this.clientParams = ResponseCreateParams.builder() | |
| .model(modelParameters.getModelName()) | |
| .input(inputBatch) | |
| .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) | |
| .instructions(modelParameters.getInstructionPrompt()) | |
| .build(); | |
| StructuredResponseCreateParams<StructuredInputOutput> clientParams = ResponseCreateParams.builder() | |
| .model(modelParameters.getModelName()) | |
| .input(inputBatch) | |
| .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) | |
| .instructions(modelParameters.getInstructionPrompt()) | |
| .build(); |
| * | ||
| * @see OpenAIModelHandler | ||
| */ | ||
| public class OpenAIModelParameters implements BaseModelParameters { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The classes OpenAIModelParameters, OpenAIModelInput, and OpenAIModelResponse do not override equals() and hashCode(). This can lead to unexpected behavior when these objects are used in collections (like Set or as keys in a Map) or in tests that rely on object equality. The test classes you've written for RemoteInferenceTest correctly implement these methods, and the production classes should as well.
For OpenAIModelParameters, you can add the following:
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
OpenAIModelParameters that = (OpenAIModelParameters) o;
return java.util.Objects.equals(apiKey, that.apiKey) &&
java.util.Objects.equals(modelName, that.modelName) &&
java.util.Objects.equals(instructionPrompt, that.instructionPrompt);
}
@Override
public int hashCode() {
return java.util.Objects.hash(apiKey, modelName, instructionPrompt);
}Similar implementations should be added to OpenAIModelInput and OpenAIModelResponse.
Base Implementation for Java Native Remote Inference
addresses #36253
Thank you for your contribution! Follow this checklist to help us incorporate your contribution quickly and easily:
addresses #123), if applicable. This will automatically add a link to the pull request in the issue. If you would like the issue to automatically close on merging the pull request, commentfixes #<ISSUE NUMBER>instead.CHANGES.mdwith noteworthy changes.See the Contributor Guide for more tips on how to make review process smoother.
To check the build health, please visit https://github.com/apache/beam/blob/master/.test-infra/BUILD_STATUS.md
GitHub Actions Tests Status (on master branch)
See CI.md for more information about GitHub Actions CI or the workflows README to see a list of phrases to trigger workflows.