Draft
Changes from all commits
23 commits
e94eda7
Add the FileAwareFactoryFn and the KerberosConsumerFactoryFn classes …
fozzie15 Aug 27, 2025
90e9f91
Revert "Add the FileAwareFactoryFn and the KerberosConsumerFactoryFn …
fozzie15 Aug 27, 2025
68230af
Add tests for file aware factory fn
fozzie15 Aug 27, 2025
4084a4a
Add changes to the build and integration files for manual testing. Be…
fozzie15 Aug 29, 2025
d866b94
Migrate to a new module such that kafka remains GCP Agnostic.
fozzie15 Sep 9, 2025
5f49b64
Clean up classes for PR review
fozzie15 Sep 9, 2025
220ba3d
Move the existing module files to the extensions repo. This module wi…
fozzie15 Sep 17, 2025
48a084e
Modify the base class to use GCS client instead of GCS FileSystems. T…
fozzie15 Sep 23, 2025
4e34e0c
Migrate to a new module such that kafka remains GCP Agnostic.
fozzie15 Sep 9, 2025
ba188e1
Move the existing module files to the extensions repo. This module wi…
fozzie15 Sep 17, 2025
43d731c
Add plumbing for python use case.
fozzie15 Sep 19, 2025
470a751
Remove accidentally committed python modules
fozzie15 Sep 29, 2025
78b8a07
Trigger CI build
fozzie15 Nov 3, 2025
8c389bc
Clean up typing.
fozzie15 Nov 5, 2025
d77f8f7
Add the FileAwareFactoryFn and the KerberosConsumerFactoryFn classes …
fozzie15 Aug 27, 2025
36310a9
Revert "Add the FileAwareFactoryFn and the KerberosConsumerFactoryFn …
fozzie15 Aug 27, 2025
4e8d9f6
Migrate to a new module such that kafka remains GCP Agnostic.
fozzie15 Sep 9, 2025
d877dc5
Move the existing module files to the extensions repo. This module wi…
fozzie15 Sep 17, 2025
502e2fc
Add plumbing for python use case.
fozzie15 Sep 19, 2025
59121d2
finish python changes with a working test
fozzie15 Sep 23, 2025
9b0c741
Add support for the custom ConsumerFactoryFn in KafkaTable
fozzie15 Oct 9, 2025
fafbb65
Clean up accidental python files included in the PR.
fozzie15 Oct 28, 2025
106dd69
remove unused files
fozzie15 Nov 7, 2025
@@ -1,3 +1,5 @@
package org.apache.beam.sdk.extensions.kafka.factories;

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
@@ -15,8 +17,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.kafka.factories;

import com.google.cloud.secretmanager.v1.AccessSecretVersionResponse;
import com.google.cloud.secretmanager.v1.SecretManagerServiceClient;
import com.google.cloud.secretmanager.v1.SecretVersionName;
@@ -207,6 +207,7 @@ protected byte[] getSecretWithCache(String secretId) {
* @return a string with all instances of external paths converted to the local paths where the
* files sit.
*/

private String replacePathWithLocal(String externalPath) throws IOException {
String externalBucketPrefixIdentifier = "://";
int externalBucketPrefixIndex = externalPath.lastIndexOf(externalBucketPrefixIdentifier);
@@ -215,7 +216,6 @@ private String replacePathWithLocal(String externalPath) throws IOException {
throw new RuntimeException(
"The provided external bucket could not be matched to a known source.");
}

int prefixLength = externalBucketPrefixIndex + externalBucketPrefixIdentifier.length();
return DIRECTORY_PREFIX + "/" + factoryType + "/" + externalPath.substring(prefixLength);
}
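As a rough illustration of the translation performed by replacePathWithLocal above (not part of the diff; the bucket, object path, local prefix, and factory type are assumed example values):

// Sketch of the substring logic: everything after "://" is appended under the
// local staging directory for this factory type (all values here are hypothetical).
String externalPath = "gs://example-bucket/configs/krb5.conf";
String marker = "://";
int prefixLength = externalPath.lastIndexOf(marker) + marker.length();
// With DIRECTORY_PREFIX = "/tmp/factory-files" and factoryType = "kerberos" (assumed),
// this yields "/tmp/factory-files/kerberos/example-bucket/configs/krb5.conf".
String localPath = "/tmp/factory-files" + "/" + "kerberos" + "/" + externalPath.substring(prefixLength);
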
@@ -37,6 +37,7 @@
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.POutput;
@@ -52,6 +53,8 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;

/**
* {@code BeamKafkaTable} represent a Kafka topic, as source or target. Need to extend to convert
* between {@code BeamSqlRow} and {@code KV<byte[], byte[]>}.
@@ -63,6 +66,7 @@ public abstract class BeamKafkaTable extends SchemaBaseBeamTable {

private TimestampPolicyFactory timestampPolicyFactory =
TimestampPolicyFactory.withProcessingTime();
private SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn;
private String bootstrapServers;
private List<String> topics;
private List<TopicPartition> topicPartitions;
@@ -81,15 +85,17 @@ public BeamKafkaTable(Schema beamSchema, String bootstrapServers, List<String> t
}

public BeamKafkaTable(
Schema beamSchema,
String bootstrapServers,
List<String> topics,
TimestampPolicyFactory timestampPolicyFactory) {
Schema beamSchema,
String bootstrapServers,
List<String> topics,
TimestampPolicyFactory timestampPolicyFactory,
SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn) {
super(beamSchema);
this.bootstrapServers = bootstrapServers;
this.topics = topics;
this.configUpdates = new HashMap<>();
this.timestampPolicyFactory = timestampPolicyFactory;
this.consumerFactoryFn = consumerFactoryFn;
}

public BeamKafkaTable(
@@ -155,6 +161,10 @@ protected KafkaIO.Read<byte[], byte[]> createKafkaRead() {
} else {
throw new InvalidTableException("One of topics and topicPartitions must be configurated.");
}
if (consumerFactoryFn != null) {
kafkaRead = kafkaRead.withConsumerFactoryFn(consumerFactoryFn);
}

return kafkaRead;
}

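For reference, a minimal sketch (not part of the diff) of the kind of factory function BeamKafkaTable now accepts. The class below is illustrative only; a real implementation such as KerberosConsumerFactoryFn would stage its configuration files before building the consumer:

import java.util.Map;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.KafkaConsumer;

/** Illustrative factory fn: builds a consumer directly from the supplied config map. */
class ExampleConsumerFactoryFn
    implements SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> {
  @Override
  public Consumer<byte[], byte[]> apply(Map<String, Object> config) {
    // A production factory (e.g. one handling Kerberos) would localize credentials or
    // config files here before constructing the consumer.
    return new KafkaConsumer<>(config);
  }
}
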
@@ -28,7 +28,9 @@
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import org.apache.beam.sdk.extensions.sql.TableUtils;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
@@ -39,10 +41,14 @@
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer;
import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializers;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.util.InstanceBuilder;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Splitter;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.kafka.clients.consumer.Consumer;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
import org.joda.time.format.PeriodFormat;
@@ -147,6 +153,33 @@ public BeamSqlTable buildBeamSqlTable(Table table) {
}
}

SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFnClass = null;
if (properties.has("consumer.factory.fn")) {
String consumerFactoryFnAsString = properties.get("consumer.factory.fn").asText();
if (consumerFactoryFnAsString.contains("KerberosConsumerFactoryFn")) {
if (!properties.has("consumer.factory.fn.params")
|| !properties.get("consumer.factory.fn.params").has("krb5Location")) {
throw new RuntimeException("KerberosConsumerFactoryFn requires a krb5Location parameter, but none was set.");
}
}
try {
consumerFactoryFnClass =
InstanceBuilder.ofType(
new TypeDescriptor<
SerializableFunction<
Map<String, Object>, Consumer<byte[], byte[]>>>() {
})
.fromClassName(properties.get("consumer.factory.fn").asText())
.withArg(String.class,
Objects
.requireNonNull(properties.get("consumer.factory.fn.params")
.get("krb5Location")
.asText()))
.build();
} catch (Exception e) {
throw new RuntimeException("Unable to construct the ConsumerFactoryFn class.", e.getMessage());
}
}

BeamKafkaTable kafkaTable = null;
if (Schemas.isNestedSchema(schema)) {
Optional<PayloadSerializer> serializer =
@@ -158,7 +191,7 @@ public BeamSqlTable buildBeamSqlTable(Table table) {
TableUtils.convertNode2Map(properties)));
kafkaTable =
new NestedPayloadKafkaTable(
schema, bootstrapServers, topics, serializer, timestampPolicyFactory);
schema, bootstrapServers, topics, serializer, timestampPolicyFactory, consumerFactoryFnClass);
} else {
/*
* CSV is handled separately because multiple rows can be produced from a single message, which
@@ -168,14 +201,14 @@
*/
if (payloadFormat.orElse("csv").equals("csv")) {
kafkaTable =
new BeamKafkaCSVTable(schema, bootstrapServers, topics, timestampPolicyFactory);
new BeamKafkaCSVTable(schema, bootstrapServers, topics, timestampPolicyFactory, consumerFactoryFnClass);
} else {
PayloadSerializer serializer =
PayloadSerializers.getSerializer(
payloadFormat.get(), schema, TableUtils.convertNode2Map(properties));
kafkaTable =
new PayloadSerializerKafkaTable(
schema, bootstrapServers, topics, serializer, timestampPolicyFactory);
schema, bootstrapServers, topics, serializer, timestampPolicyFactory, consumerFactoryFnClass);
}
}

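To make the property lookups above concrete, here is a sketch (not part of the diff) of the JSON shape a table definition would carry to enable the Kerberos factory. The keys "consumer.factory.fn", "consumer.factory.fn.params", and "krb5Location" come from this change; the class path and GCS location are assumed example values:

// Illustrative only: builds an ObjectNode matching the checks in buildBeamSqlTable.
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;

ObjectMapper mapper = new ObjectMapper();
ObjectNode exampleProperties = mapper.createObjectNode();
exampleProperties.put(
    "consumer.factory.fn",
    "org.apache.beam.sdk.extensions.kafka.factories.KerberosConsumerFactoryFn");
exampleProperties
    .putObject("consumer.factory.fn.params")
    .put("krb5Location", "gs://example-bucket/krb5.conf");
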
@@ -23,6 +23,7 @@

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.beam.sdk.io.kafka.KafkaRecord;
@@ -33,6 +34,7 @@
import org.apache.beam.sdk.schemas.io.payloads.PayloadSerializer;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptor;
@@ -41,6 +43,7 @@
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.header.Headers;
@@ -62,7 +65,8 @@ public NestedPayloadKafkaTable(
bootstrapServers,
topics,
payloadSerializer,
TimestampPolicyFactory.withLogAppendTime());
TimestampPolicyFactory.withLogAppendTime(),
null);
}

public NestedPayloadKafkaTable(
@@ -71,7 +75,23 @@ public NestedPayloadKafkaTable(
List<String> topics,
Optional<PayloadSerializer> payloadSerializer,
TimestampPolicyFactory timestampPolicyFactory) {
super(beamSchema, bootstrapServers, topics, timestampPolicyFactory);
this(
beamSchema,
bootstrapServers,
topics,
payloadSerializer,
timestampPolicyFactory,
/*consumerFactoryFn=*/null);
}

public NestedPayloadKafkaTable(
Schema beamSchema,
String bootstrapServers,
List<String> topics,
Optional<PayloadSerializer> payloadSerializer,
TimestampPolicyFactory timestampPolicyFactory,
SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn) {
super(beamSchema, bootstrapServers, topics, timestampPolicyFactory, consumerFactoryFn);

checkArgument(Schemas.isNestedSchema(schema));
Schemas.validateNestedSchema(schema);
2 changes: 2 additions & 0 deletions sdks/java/io/expansion-service/build.gradle
@@ -76,6 +76,8 @@ dependencies {
permitUnusedDeclared project(":sdks:java:io:kafka") // BEAM-11761
implementation project(":sdks:java:io:kafka:upgrade")
permitUnusedDeclared project(":sdks:java:io:kafka:upgrade") // BEAM-11761
implementation project(":sdks:java:extensions:kafka-factories")
permitUnusedDeclared project(":sdks:java:extensions:kafka-factories")

if (JavaVersion.current().compareTo(JavaVersion.VERSION_11) >= 0 && project.findProperty('testJavaVersion') != '8') {
// iceberg ended support for Java 8 in 1.7.0
@@ -35,6 +35,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;
@@ -94,6 +95,7 @@
import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.Manual;
import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.MonotonicallyIncreasing;
import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators.WallTime;
import org.apache.beam.sdk.util.InstanceBuilder;
import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.sdk.util.construction.PTransformMatchers;
import org.apache.beam.sdk.util.construction.ReplacementOutputs;
@@ -930,6 +932,34 @@ static <K, V> void setupExternalBuilder(
builder.setOffsetDeduplication(false);
builder.setRedistributeByRecordKey(false);
}

if (config.consumerFactoryFnClass != null) {
if (config.consumerFactoryFnClass.contains("KerberosConsumerFactoryFn")) {
try {
if (!config.consumerFactoryFnParams.containsKey("krb5Location")) {
throw new IllegalArgumentException(
"The KerberosConsumerFactoryFn requires a location for the krb5.conf file. "
+ "Please provide either a GCS location or Google Secret Manager location for this file.");
}
String krb5Location = config.consumerFactoryFnParams.get("krb5Location");
builder.setConsumerFactoryFn(
InstanceBuilder.ofType(
new TypeDescriptor<
SerializableFunction<
Map<String, Object>, Consumer<byte[], byte[]>>>() {})
.fromClassName(config.consumerFactoryFnClass)
.withArg(String.class, Objects.requireNonNull(krb5Location))
.build());
} catch (Exception e) {
throw new RuntimeException(
"Unable to construct FactoryFn "
+ config.consumerFactoryFnClass
+ ": "
+ e.getMessage(),
e);
}
}
}
}

private static <T> Coder<T> resolveCoder(Class<Deserializer<T>> deserializer) {
Expand Down Expand Up @@ -1000,6 +1030,8 @@ public static class Configuration {
private Boolean offsetDeduplication;
private Boolean redistributeByRecordKey;
private Long dynamicReadPollIntervalSeconds;
private String consumerFactoryFnClass;
private Map<String, String> consumerFactoryFnParams;

public void setConsumerConfig(Map<String, String> consumerConfig) {
this.consumerConfig = consumerConfig;
@@ -1068,6 +1100,14 @@ public void setRedistributeByRecordKey(Boolean redistributeByRecordKey) {
public void setDynamicReadPollIntervalSeconds(Long dynamicReadPollIntervalSeconds) {
this.dynamicReadPollIntervalSeconds = dynamicReadPollIntervalSeconds;
}

public void setConsumerFactoryFnClass(String consumerFactoryFnClass) {
this.consumerFactoryFnClass = consumerFactoryFnClass;
}

public void setConsumerFactoryFnParams(Map<String, String> consumerFactoryFnParams) {
this.consumerFactoryFnParams = consumerFactoryFnParams;
}
}
}

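For intuition, the InstanceBuilder call in setupExternalBuilder above is roughly equivalent to the direct construction sketched below. The single-String-argument constructor is inferred from the withArg(String.class, ...) usage and is an assumption; the reflective form exists so the Kafka module need not depend on the kafka-factories extension, so this sketch is for illustration only and is not code that belongs in KafkaIO:

// Rough equivalent of the reflective construction above (assumed constructor signature).
String krb5Location = config.consumerFactoryFnParams.get("krb5Location");
SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> factoryFn =
    new org.apache.beam.sdk.extensions.kafka.factories.KerberosConsumerFactoryFn(krb5Location);
builder.setConsumerFactoryFn(factoryFn);
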
48 changes: 30 additions & 18 deletions sdks/python/apache_beam/io/kafka.py
@@ -100,6 +100,7 @@

# pytype: skip-file

import collections
import typing

import numpy as np
@@ -110,22 +111,21 @@

ReadFromKafkaSchema = typing.NamedTuple(
'ReadFromKafkaSchema',
[
('consumer_config', typing.Mapping[str, str]),
('topics', typing.List[str]),
('key_deserializer', str),
('value_deserializer', str),
('start_read_time', typing.Optional[int]),
('max_num_records', typing.Optional[int]),
('max_read_time', typing.Optional[int]),
('commit_offset_in_finalize', bool),
('timestamp_policy', str),
('consumer_polling_timeout', typing.Optional[int]),
('redistribute', typing.Optional[bool]),
('redistribute_num_keys', typing.Optional[np.int32]),
('allow_duplicates', typing.Optional[bool]),
('dynamic_read_poll_interval_seconds', typing.Optional[int]),
])
[('consumer_config', typing.Mapping[str, str]),
('topics', typing.List[str]), ('key_deserializer', str),
('value_deserializer', str), ('start_read_time', typing.Optional[int]),
('max_num_records', typing.Optional[int]),
('max_read_time', typing.Optional[int]),
('commit_offset_in_finalize', bool), ('timestamp_policy', str),
('consumer_polling_timeout', typing.Optional[int]),
('redistribute', typing.Optional[bool]),
('redistribute_num_keys', typing.Optional[np.int32]),
('allow_duplicates', typing.Optional[bool]),
('dynamic_read_poll_interval_seconds', typing.Optional[int]),
('consumer_factory_fn_class', typing.Optional[str]),
(
'consumer_factory_fn_params',
typing.Optional[collections.abc.Mapping[str, str]])])


def default_io_expansion_service(append_args=None):
@@ -173,7 +173,10 @@ def __init__(
redistribute_num_keys=np.int32(0),
allow_duplicates=False,
dynamic_read_poll_interval_seconds: typing.Optional[int] = None,
):
consumer_factory_fn_class: typing.Optional[str] = None,
consumer_factory_fn_params: typing.Optional[
collections.abc.Mapping] = None):
"""
Initializes a read operation from Kafka.

@@ -216,6 +219,13 @@ def __init__(
:param dynamic_read_poll_interval_seconds: The interval in seconds at which
to check for new partitions. If not None, dynamic partition discovery
is enabled.
:param consumer_factory_fn_class: The fully qualified class name of a
provided consumerFactoryFn. If not None, Kafka consumers are constructed
through this factory, allowing a custom consumer configuration.
:param consumer_factory_fn_params: A map of parameters for the provided
consumer_factory_fn_class. If not None, the values in this map are used
when constructing the consumer_factory_fn_class object. This must not be
None if consumer_factory_fn_class is not None.
"""
if timestamp_policy not in [ReadFromKafka.processing_time_policy,
ReadFromKafka.create_time_policy,
@@ -242,7 +252,9 @@ def __init__(
redistribute_num_keys=redistribute_num_keys,
allow_duplicates=allow_duplicates,
dynamic_read_poll_interval_seconds=
dynamic_read_poll_interval_seconds)),
dynamic_read_poll_interval_seconds,
consumer_factory_fn_class=consumer_factory_fn_class,
consumer_factory_fn_params=consumer_factory_fn_params)),
expansion_service or default_io_expansion_service())


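Finally, a minimal usage sketch of the new Python parameters (not part of the diff; the broker address, topic, factory class path, and GCS location are illustrative assumptions, and a cross-language expansion service must be reachable at runtime):

# Illustrative only: wiring a custom consumer factory through the x-lang Kafka read.
import apache_beam as beam
from apache_beam.io.kafka import ReadFromKafka

with beam.Pipeline() as pipeline:
  records = (
      pipeline
      | ReadFromKafka(
          consumer_config={'bootstrap.servers': 'broker-1:9092'},
          topics=['secured_topic'],
          # Assumed fully qualified class name for the factory added in this PR.
          consumer_factory_fn_class=(
              'org.apache.beam.sdk.extensions.kafka.factories.'
              'KerberosConsumerFactoryFn'),
          consumer_factory_fn_params={
              'krb5Location': 'gs://example-bucket/krb5.conf'
          }))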