Skip to content

Conversation

@Ganeshsivakumar
Copy link

@Ganeshsivakumar Ganeshsivakumar commented Oct 25, 2025

Base Implementation for Java Native Remote Inference

addresses #36253


Thank you for your contribution! Follow this checklist to help us incorporate your contribution quickly and easily:

  • Mention the appropriate issue in your description (for example: addresses #123), if applicable. This will automatically add a link to the pull request in the issue. If you would like the issue to automatically close on merging the pull request, comment fixes #<ISSUE NUMBER> instead.
  • Update CHANGES.md with noteworthy changes.
  • If this contribution is large, please file an Apache Individual Contributor License Agreement.

See the Contributor Guide for more tips on how to make review process smoother.

To check the build health, please visit https://github.com/apache/beam/blob/master/.test-infra/BUILD_STATUS.md

GitHub Actions Tests Status (on master branch)

Build python source distribution and wheels
Python tests
Java tests
Go tests

See CI.md for more information about GitHub Actions CI or the workflows README to see a list of phrases to trigger workflows.

@github-actions github-actions bot added the java label Oct 25, 2025
@Ganeshsivakumar
Copy link
Author

Hi @jrmccluskey, would like to get your review. Thanks.

Copy link
Contributor

@jrmccluskey jrmccluskey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the bones here, although I think there needs to be a bit of work around the inputs being sequences of input types + having the outputs be (InputT, OutputT) tuples. Makes it easier to start from that point than to try and retrofit things once batching exists.

return builder().setParameters(modelParameters).build();
}

@Override
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the Python implementation (as well as the cross-language implementation in Java) we're generally trying to return input-output pairs to make it easier to process the results downstream.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated output format.

@Ganeshsivakumar Ganeshsivakumar marked this pull request as ready for review October 29, 2025 10:53
@Ganeshsivakumar Ganeshsivakumar changed the title [WIP] Java Native Remote Inference Java Native Remote Inference Nov 6, 2025
@github-actions
Copy link
Contributor

github-actions bot commented Nov 6, 2025

Checks are failing. Will not request review until checks are succeeding. If you'd like to override that behavior, comment assign set of reviewers

@github-actions
Copy link
Contributor

Assigning reviewers:

R: @kennknowles for label java.

Note: If you would like to opt out of this review, comment assign to next reviewer.

Available commands:

  • stop reviewer notifications - opt out of the automated review tooling
  • remind me after tests pass - tag the comment author after tests pass
  • waiting on author - shift the attention set back to the author (any comment or push by the author will return the attention set to the reviewers)

The PR bot will only process comments in the main thread (not review comments).

@codecov
Copy link

codecov bot commented Nov 11, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 36.22%. Comparing base (312509f) to head (b3a2c67).
⚠️ Report is 295 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff            @@
##             master   #36623   +/-   ##
=========================================
  Coverage     36.22%   36.22%           
  Complexity     1666     1666           
=========================================
  Files          1058     1058           
  Lines        165088   165088           
  Branches       1190     1190           
=========================================
  Hits          59796    59796           
  Misses       103116   103116           
  Partials       2176     2176           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Copy link
Contributor

@jrmccluskey jrmccluskey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From a pure functionality standpoint I think this is in a good place. I'd like to see at least a naive retry wrapper around requests before merging this in, and I'm going to use the Gemini reviewer + some Beam committers with better knowledge of Java style to make sure this looks good

// Core Beam SDK
implementation(project(":sdks:java:core"))

implementation("com.openai:openai-java:4.3.0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make more sense to scope the Open AI dependency for the Open AI model handler specifically? Otherwise any usage of a remote model handler, regardless of target service, would have a bunch of extra dependencies.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally, I thought about as well, model handlers can be their own separate modules.

Comment on lines +161 to +165
@ProcessElement
public void processElement(ProcessContext c) {
Iterable<PredictionResult<InputT, OutputT>> response = this.handler.request(c.element());
c.output(response);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The next step here is to add some sort of try-catch logic around the request call. I'd recommend looking through the RequestResponseIO code (https://github.com/apache/beam/tree/ce1b1dcbc596d1e7c914ee0f7b0d48f2d2bf87e1/sdks/java/io/rrio/src/main/java/org/apache/beam/io/requestresponse) to get an idea of what that looks like.

@jrmccluskey
Copy link
Contributor

@gemini-code-assist review

@jrmccluskey jrmccluskey requested a review from Abacn November 13, 2025 14:50
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new RemoteInference transform for the Java SDK, providing a framework for making calls to external ML services, along with an implementation for OpenAI. The overall structure is well-designed, with a clear separation of concerns between the generic framework and the model-specific handlers. The tests are also comprehensive, covering unit, integration, and failure scenarios.

My review includes several suggestions to improve performance, correctness, and type safety. Key points include:

  • Fixing a critical dependency issue in build.gradle.kts.
  • Implementing batching to improve performance by reducing remote calls.
  • Reusing the ObjectMapper instance in OpenAIModelHandler for better performance.
  • Improving dependency scoping for logging libraries.
  • Correcting an invalid model name in an integration test.

Addressing these points will make the new feature more robust and performant. Great work on this new capability!

compileOnly("com.google.auto.value:auto-value-annotations:1.11.0")
compileOnly("org.checkerframework:checker-qual:3.42.0")
annotationProcessor("com.google.auto.value:auto-value:1.11.0")
implementation("com.fasterxml.jackson.core:jackson-core:2.20.0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The specified version 2.20.0 for com.fasterxml.jackson.core:jackson-core does not appear to be a valid published version and will likely cause build failures. Please use a valid version. For example, 2.17.1 is a recent stable version.

  implementation("com.fasterxml.jackson.core:jackson-core:2.17.1")

implementation("com.fasterxml.jackson.core:jackson-core:2.20.0")
implementation("org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1")
implementation("org.slf4j:slf4j-api:2.0.9")
implementation("org.slf4j:slf4j-simple:2.0.9")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The slf4j-simple dependency should be scoped to testImplementation instead of implementation. Including a concrete logging implementation like slf4j-simple in the implementation scope can lead to classpath conflicts in downstream projects that use this module, as they may have their own preferred logging framework. It's a best practice for libraries to only depend on slf4j-api.

  testImplementation("org.slf4j:slf4j-simple:2.0.9")

Comment on lines +116 to +124
return input
.apply("WrapInputInList", MapElements.via(new SimpleFunction<InputT, List<InputT>>() {
@Override
public List<InputT> apply(InputT element) {
return Collections.singletonList(element);
}
}))
// Pass the list to the inference function
.apply("RemoteInference", ParDo.of(new RemoteInferenceFn<InputT, OutputT>(this)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation processes each input element individually by wrapping it in a singleton list. This results in a separate remote inference call for every element, which is highly inefficient and will lead to poor performance, especially for large datasets. The BaseModelHandler#request method already accepts a List<InputT>, suggesting that batching is intended.

To improve performance, you should introduce batching before the ParDo transform. You can use org.apache.beam.sdk.transforms.GroupIntoBatches to group elements into batches of a configurable size. This will significantly reduce the number of remote calls.

For example, you could add a batchSize parameter to the Invoke transform and then use it like this:

// First, add a key to each element
input.apply(WithKeys.of("key"))
     // Then, group elements into batches
     .apply(GroupIntoBatches.ofSize(batchSize))
     // Then, get the values (the batches)
     .apply(Values.create())
     // Finally, perform the remote inference on each batch
     .apply("RemoteInference", ParDo.of(new RemoteInferenceFn<InputT, OutputT>(this)));

Comment on lines +98 to +99
String inputBatch = new ObjectMapper()
.writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

A new ObjectMapper instance is created for every call to the request method. ObjectMapper is a heavy object to create, and it is thread-safe. Creating it repeatedly in a hot path like this can negatively impact performance.

You should create the ObjectMapper instance once and reuse it. A good place to initialize it would be in the createClient method and store it in a transient field.

  1. Add a field to OpenAIModelHandler:
private transient ObjectMapper objectMapper;
  1. Initialize it in createClient:
@Override
public void createClient(OpenAIModelParameters parameters) {
  this.modelParameters = parameters;
  this.client = OpenAIOkHttpClient.builder()
    .apiKey(this.modelParameters.getApiKey())
    .build();
  this.objectMapper = new ObjectMapper();
}
  1. Use the field in request:
Suggested change
String inputBatch = new ObjectMapper()
.writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList());
String inputBatch = objectMapper
.writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList());

.handler(OpenAIModelHandler.class)
.withParameters(OpenAIModelParameters.builder()
.apiKey(apiKey)
.modelName("gpt-5")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The integration test testWithDifferentModel uses the model name "gpt-5". This model is not currently available, and using it will cause the integration test to fail with an "invalid model" error from the OpenAI API. Please use a valid and available model name to ensure the test can run successfully. For example, you could use another real model like gpt-4-turbo.

Suggested change
.modelName("gpt-5")
.modelName("gpt-4-turbo")

* }</pre>
*
*/
@SuppressWarnings({ "rawtypes", "unchecked" })
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The class is suppressed with @SuppressWarnings({ "rawtypes", "unchecked" }) because BaseModelHandler is used as a raw type. This reduces type safety and can hide potential class cast exceptions at runtime.

For example, Invoke.handler() returns Class<? extends BaseModelHandler>, and RemoteInferenceFn.handler is a raw BaseModelHandler. This means the compiler cannot verify that the BaseModelParameters and BaseInput/BaseResponse types are compatible between the Invoke transform and the BaseModelHandler implementation.

While fixing this might require some significant refactoring of the generics, it would make the framework more robust. A potential direction could be to include the BaseModelParameters type in the Invoke transform's generics, like Invoke<InputT, OutputT, ParamT extends BaseModelParameters>. This would allow for stronger type checking throughout the implementation.

implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> {

private transient OpenAIClient client;
private transient StructuredResponseCreateParams<StructuredInputOutput> clientParams;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The field clientParams is declared as a transient instance variable but it is only assigned and used within the request method. This makes the code harder to reason about, as it suggests clientParams holds state across method calls, which it doesn't. It should be a local variable within the request method.

Comment on lines +102 to +107
this.clientParams = ResponseCreateParams.builder()
.model(modelParameters.getModelName())
.input(inputBatch)
.text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO)
.instructions(modelParameters.getInstructionPrompt())
.build();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The clientParams is being assigned to an instance field, but it's only used locally within the request method. It should be a local variable. This also requires removing the clientParams field from the class.

Suggested change
this.clientParams = ResponseCreateParams.builder()
.model(modelParameters.getModelName())
.input(inputBatch)
.text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO)
.instructions(modelParameters.getInstructionPrompt())
.build();
StructuredResponseCreateParams<StructuredInputOutput> clientParams = ResponseCreateParams.builder()
.model(modelParameters.getModelName())
.input(inputBatch)
.text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO)
.instructions(modelParameters.getInstructionPrompt())
.build();

*
* @see OpenAIModelHandler
*/
public class OpenAIModelParameters implements BaseModelParameters {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The classes OpenAIModelParameters, OpenAIModelInput, and OpenAIModelResponse do not override equals() and hashCode(). This can lead to unexpected behavior when these objects are used in collections (like Set or as keys in a Map) or in tests that rely on object equality. The test classes you've written for RemoteInferenceTest correctly implement these methods, and the production classes should as well.

For OpenAIModelParameters, you can add the following:

@Override
public boolean equals(Object o) {
  if (this == o) return true;
  if (o == null || getClass() != o.getClass()) return false;
  OpenAIModelParameters that = (OpenAIModelParameters) o;
  return java.util.Objects.equals(apiKey, that.apiKey) &&
         java.util.Objects.equals(modelName, that.modelName) &&
         java.util.Objects.equals(instructionPrompt, that.instructionPrompt);
}

@Override
public int hashCode() {
  return java.util.Objects.hash(apiKey, modelName, instructionPrompt);
}

Similar implementations should be added to OpenAIModelInput and OpenAIModelResponse.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants