ConsumerSeekAware.java
@@ -19,8 +19,8 @@
import java.util.Collection;
import java.util.Map;
import java.util.function.Function;

import org.apache.kafka.common.TopicPartition;
import org.springframework.lang.Nullable;

/**
* Listeners that implement this interface are provided with a
@@ -29,8 +29,8 @@
*
* @author Gary Russell
* @author Soby Chacko
* @author Borahm Lee
* @since 1.1
*
*/
public interface ConsumerSeekAware {

@@ -39,15 +39,17 @@ public interface ConsumerSeekAware {
* {@code ConcurrentMessageListenerContainer} or the same listener instance in multiple
* containers, listeners should store the callback in a {@code ThreadLocal} or a map keyed
* by the thread.
*
* @param callback the callback.
*/
default void registerSeekCallback(ConsumerSeekCallback callback) {
}

/**
* When using group management, called when partition assignments change.
*
* @param assignments the new assignments and their current offsets.
* @param callback the callback to perform an initial seek after assignment.
*/
default void onPartitionsAssigned(Map<TopicPartition, Long> assignments, ConsumerSeekCallback callback) {
}
@@ -56,6 +58,7 @@ default void onPartitionsAssigned(Map<TopicPartition, Long> assignments, Consume
* When using group management, called when partition assignments are revoked.
* Listeners should discard any callback saved from
* {@link #registerSeekCallback(ConsumerSeekCallback)} on this thread.
*
* @param partitions the partitions that have been revoked.
* @since 2.3
*/
@@ -65,8 +68,9 @@ default void onPartitionsRevoked(Collection<TopicPartition> partitions) {
/**
* If the container is configured to emit idle container events, this method is called
* when the container idle event is emitted - allowing a seek operation.
*
* @param assignments the new assignments and their current offsets.
* @param callback the callback to perform a seek.
*/
default void onIdleContainer(Map<TopicPartition, Long> assignments, ConsumerSeekCallback callback) {
}
@@ -75,6 +79,7 @@ default void onIdleContainer(Map<TopicPartition, Long> assignments, ConsumerSeek
* When using manual partition assignment, called when the first poll has completed;
* useful when using {@code auto.offset.reset=latest} and you need to wait until the
* initial position has been established.
*
* @since 2.8.8
*/
default void onFirstPoll() {
@@ -83,6 +88,7 @@ default void onFirstPoll() {
/**
* Called when the listener consumer terminates allowing implementations to clean up
* state, such as thread locals.
*
* @since 2.4
*/
default void unregisterSeekCallback() {
@@ -101,9 +107,10 @@ interface ConsumerSeekCallback {
* queue the seek operation to the consumer. The queued seek will occur after any
* pending offset commits. The consumer must be currently assigned the specified
* partition.
*
* @param topic the topic.
* @param partition the partition.
* @param offset the offset (absolute).
*/
void seek(String topic, int partition, long offset);

@@ -117,8 +124,9 @@ interface ConsumerSeekCallback {
* queue the seek operation to the consumer. The queued seek will occur after any
* pending offset commits. The consumer must be currently assigned the specified
* partition.
*
* @param topic the topic.
* @param partition the partition.
* @param offsetComputeFunction function to compute the absolute offset to seek to.
* @since 3.2.0
*/
@@ -132,7 +140,8 @@ interface ConsumerSeekCallback {
* the seek operation to the consumer. The queued seek will occur after
* any pending offset commits. The consumer must be currently assigned the
* specified partition.
*
* @param topic the topic.
* @param partition the partition.
*/
void seekToBeginning(String topic, int partition);
@@ -145,6 +154,7 @@ interface ConsumerSeekCallback {
* queue the seek operation to the consumer for each
* {@link TopicPartition}. The seek will occur after any pending offset commits.
* The consumer must be currently assigned the specified partition(s).
*
* @param partitions the {@link TopicPartition}s.
* @since 2.3.4
*/
@@ -160,7 +170,8 @@ default void seekToBeginning(Collection<TopicPartition> partitions) {
* the seek operation to the consumer. The queued seek will occur after any
* pending offset commits. The consumer must be currently assigned the specified
* partition.
*
* @param topic the topic.
* @param partition the partition.
*/
void seekToEnd(String topic, int partition);
@@ -173,6 +184,7 @@ default void seekToBeginning(Collection<TopicPartition> partitions) {
* the seek operation to the consumer for each {@link TopicPartition}. The queued
* seek(s) will occur after any pending offset commits. The consumer must be
* currently assigned the specified partition(s).
*
* @param partitions the {@link TopicPartition}s.
* @since 2.3.4
*/
@@ -187,12 +199,13 @@ default void seekToEnd(Collection<TopicPartition> partitions) {
* perform the seek immediately on the consumer. When called from elsewhere, queue
* the seek operation. The queued seek will occur after any pending offset
* commits. The consumer must be currently assigned the specified partition.
*
* @param topic the topic.
* @param partition the partition.
* @param offset the offset; positive values are relative to the start, negative
* values are relative to the end, unless toCurrent is true.
* @param toCurrent true for the offset to be relative to the current position
* rather than the beginning or end.
* @since 2.3
*/
void seekRelative(String topic, int partition, long offset, boolean toCurrent);
@@ -207,11 +220,12 @@ default void seekToEnd(Collection<TopicPartition> partitions) {
* commits. The consumer must be currently assigned the specified partition. Use
* {@link #seekToTimestamp(Collection, long)} when seeking multiple partitions
* because the offset lookup is blocking.
*
* @param topic the topic.
* @param partition the partition.
* @param timestamp the time stamp.
* @see #seekToTimestamp(Collection, long)
* @since 2.3
*/
void seekToTimestamp(String topic, int partition, long timestamp);

@@ -223,12 +237,26 @@ default void seekToEnd(Collection<TopicPartition> partitions) {
* perform the seek immediately on the consumer. When called from elsewhere, queue
* the seek operation. The queued seek will occur after any pending offset
* commits. The consumer must be currently assigned the specified partition.
*
* @param topicPartitions the topic/partitions.
* @param timestamp the time stamp.
* @since 2.3
*/
void seekToTimestamp(Collection<TopicPartition> topicPartitions, long timestamp);


/**
* Retrieve the group ID associated with this consumer seek callback, if available.
* This method returns {@code null} by default, indicating that the group ID is not specified.
* Implementations may override this method to provide a specific group ID value.
*
* @return the consumer group ID.
* @since 3.3
*/
@Nullable
default String getGroupId() {
return null;
}
}

}
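For orientation only (this sketch is not part of the change set): the new ConsumerSeekCallback.getGroupId() lets a listener that stores its callbacks decide, per callback, which consumer group a queued seek should apply to. A minimal sketch under assumed names (the class, topic, and group id are invented; AbstractConsumerSeekAware, exercised in the test further down, already provides similar bookkeeping out of the box):

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.kafka.listener.ConsumerSeekAware;

public class GroupFilteringSeeker implements ConsumerSeekAware {

	// One callback per consumer thread; removed again in unregisterSeekCallback().
	private final Map<Thread, ConsumerSeekCallback> callbacks = new ConcurrentHashMap<>();

	@Override
	public void registerSeekCallback(ConsumerSeekCallback callback) {
		this.callbacks.put(Thread.currentThread(), callback);
	}

	@Override
	public void unregisterSeekCallback() {
		this.callbacks.remove(Thread.currentThread());
	}

	// Rewind one topic/partition, but only for consumers of the given group;
	// getGroupId() may return null, so compare starting from the known group id.
	public void seekToBeginningFor(String groupId, String topic, int partition) {
		this.callbacks.values().forEach(callback -> {
			if (groupId.equals(callback.getGroupId())) {
				callback.seekToBeginning(topic, partition);
			}
		});
	}

}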
KafkaMessageListenerContainer.java
@@ -165,6 +165,7 @@
* @author Raphael Rösch
* @author Christian Mergenthaler
* @author Mikael Carlstedt
* @author Borahm Lee
*/
public class KafkaMessageListenerContainer<K, V> // NOSONAR line count
extends AbstractMessageListenerContainer<K, V> implements ConsumerPauseResumeEventPublisher {
@@ -681,7 +682,7 @@ private final class ListenerConsumer implements SchedulingAwareRunnable, Consume

private final TransactionTemplate transactionTemplate;

private final String consumerGroupId = getGroupId();
private final String consumerGroupId = KafkaMessageListenerContainer.this.getGroupId();

private final TaskScheduler taskScheduler;

@@ -1362,8 +1363,8 @@ protected void initialize() {
}
publishConsumerStartingEvent();
this.consumerThread = Thread.currentThread();
setupSeeks();
KafkaUtils.setConsumerGroupId(this.consumerGroupId);
setupSeeks();
this.count = 0;
this.last = System.currentTimeMillis();
initAssignedPartitions();
@@ -1906,7 +1907,7 @@ private void wrapUp(@Nullable Throwable throwable) {
this.consumerSeekAwareListener.onPartitionsRevoked(partitions);
this.consumerSeekAwareListener.unregisterSeekCallback();
}
this.logger.info(() -> getGroupId() + ": Consumer stopped");
this.logger.info(() -> KafkaMessageListenerContainer.this.getGroupId() + ": Consumer stopped");
publishConsumerStoppedEvent(throwable);
}

@@ -2693,7 +2694,7 @@ private RuntimeException doInvokeRecordListener(final ConsumerRecord<K, V> cReco
Observation observation = KafkaListenerObservation.LISTENER_OBSERVATION.observation(
this.containerProperties.getObservationConvention(),
DefaultKafkaListenerObservationConvention.INSTANCE,
() -> new KafkaRecordReceiverContext(cRecord, getListenerId(), getClientId(), getGroupId(),
() -> new KafkaRecordReceiverContext(cRecord, getListenerId(), getClientId(), KafkaMessageListenerContainer.this.getGroupId(),
this::clusterId),
this.observationRegistry);
return observation.observe(() -> {
@@ -3327,6 +3328,11 @@ public void seekToTimestamp(Collection<TopicPartition> topicParts, long timestam
topicParts.forEach(tp -> seekToTimestamp(tp.topic(), tp.partition(), timestamp));
}

@Override
public String getGroupId() {
return KafkaMessageListenerContainer.this.getGroupId();
}

@Override
public String toString() {
return "KafkaMessageListenerContainer.ListenerConsumer ["
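A note on the KafkaMessageListenerContainer.this.getGroupId() qualifications above (my reading of the diff, not stated in it): ListenerConsumer implements ConsumerSeekCallback, which now declares its own getGroupId(), so an unqualified getGroupId() inside the inner class binds to that new override. The explicit qualification makes the field initializer, the shutdown log message, and the observation context read the container's group id directly, and inside the override itself it is what prevents infinite recursion. A stripped-down illustration with invented names:

public class Outer {

	String getGroupId() {
		return "outer-group";
	}

	class Inner {

		// Mirrors the new ConsumerSeekCallback override: the call must be qualified,
		// since a bare getGroupId() here would call this method forever.
		String getGroupId() {
			return Outer.this.getGroupId();
		}

		String describe() {
			// Unqualified call binds to Inner.getGroupId(), not to Outer's method.
			return getGroupId();
		}

	}

}

Similarly, swapping KafkaUtils.setConsumerGroupId(this.consumerGroupId) ahead of setupSeeks() in initialize() reads as making the thread-bound group id available before the seek callbacks are registered, though the diff itself does not spell that out.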
AbstractConsumerSeekAwareTests.java
@@ -0,0 +1,132 @@
package org.springframework.kafka.listener;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.kafka.annotation.EnableKafka;
import org.springframework.kafka.annotation.KafkaListener;
import org.springframework.kafka.config.ConcurrentKafkaListenerContainerFactory;
import org.springframework.kafka.core.ConsumerFactory;
import org.springframework.kafka.core.DefaultKafkaConsumerFactory;
import org.springframework.kafka.core.DefaultKafkaProducerFactory;
import org.springframework.kafka.core.KafkaTemplate;
import org.springframework.kafka.core.ProducerFactory;
import org.springframework.kafka.listener.AbstractConsumerSeekAwareTests.Config.MultiGroupListener;
import org.springframework.kafka.test.EmbeddedKafkaBroker;
import org.springframework.kafka.test.context.EmbeddedKafka;
import org.springframework.kafka.test.utils.KafkaTestUtils;
import org.springframework.stereotype.Component;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;

import static org.assertj.core.api.Assertions.assertThat;

@DirtiesContext
@SpringJUnitConfig
@EmbeddedKafka(topics = {AbstractConsumerSeekAwareTests.TOPIC}, partitions = 3)
public class AbstractConsumerSeekAwareTests {

static final String TOPIC = "Seek";

@Autowired
Config config;

@Autowired
KafkaTemplate<String, String> template;

@Autowired
MultiGroupListener multiGroupListener;

@Test
public void seekForAllGroups() throws Exception {
template.send(TOPIC, "test-data");
template.send(TOPIC, "test-data");
assertThat(MultiGroupListener.latch1.await(10, TimeUnit.SECONDS)).isTrue();
assertThat(MultiGroupListener.latch2.await(10, TimeUnit.SECONDS)).isTrue();

MultiGroupListener.latch1 = new CountDownLatch(2);
MultiGroupListener.latch2 = new CountDownLatch(2);

multiGroupListener.seekToBeginning();
assertThat(MultiGroupListener.latch1.await(10, TimeUnit.SECONDS)).isTrue();
assertThat(MultiGroupListener.latch2.await(10, TimeUnit.SECONDS)).isTrue();
}

@Test
public void seekForSpecificGroup() throws Exception {
template.send(TOPIC, "test-data");
template.send(TOPIC, "test-data");
assertThat(MultiGroupListener.latch1.await(10, TimeUnit.SECONDS)).isTrue();
assertThat(MultiGroupListener.latch2.await(10, TimeUnit.SECONDS)).isTrue();

MultiGroupListener.latch1 = new CountDownLatch(2);
MultiGroupListener.latch2 = new CountDownLatch(2);

multiGroupListener.seekToBeginningForGroup("group2");
assertThat(MultiGroupListener.latch2.await(10, TimeUnit.SECONDS)).isTrue();
assertThat(MultiGroupListener.latch1.getCount()).isEqualTo(2);
}

@EnableKafka
@Configuration
static class Config {

@Autowired
EmbeddedKafkaBroker broker;

@Bean
public ConcurrentKafkaListenerContainerFactory<String, String> kafkaListenerContainerFactory(
ConsumerFactory<String, String> consumerFactory) {
ConcurrentKafkaListenerContainerFactory<String, String> factory = new ConcurrentKafkaListenerContainerFactory<>();
factory.setConsumerFactory(consumerFactory);
return factory;
}

@Bean
ConsumerFactory<String, String> consumerFactory() {
return new DefaultKafkaConsumerFactory<>(KafkaTestUtils.consumerProps("test-group", "false", this.broker));
}

@Bean
ProducerFactory<String, String> producerFactory() {
return new DefaultKafkaProducerFactory<>(KafkaTestUtils.producerProps(this.broker));
}

@Bean
KafkaTemplate<String, String> template(ProducerFactory<String, String> pf) {
return new KafkaTemplate<>(pf);
}

@Component
static class MultiGroupListener extends AbstractConsumerSeekAware {

static CountDownLatch latch1 = new CountDownLatch(2);
static CountDownLatch latch2 = new CountDownLatch(2);

@KafkaListener(groupId = "group1", topics = TOPIC)
void listenForGroup1(String in) {
// System.out.printf("[group1] in = %s\n", in); // TODO remove
latch1.countDown();
}

@KafkaListener(groupId = "group2", topics = TOPIC)
void listenForGroup2(String in) {
// System.out.printf("[group2] in = %s\n", in); // TODO remove
latch2.countDown();
}

public void seekToBeginningForGroup(String groupIdForSeek) {
getCallbacksAndTopics().forEach((cb, topics) -> {
if (groupIdForSeek.equals(cb.getGroupId())) {
topics.forEach(tp -> cb.seekToBeginning(tp.topic(), tp.partition()));
}
});
}
}
}

}
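Beyond the assertions above, the same getGroupId() filter generalizes to the other ConsumerSeekCallback operations. A hypothetical variation on the test's MultiGroupListener (class and method names invented here), replaying one group's assigned partitions from a point in time:

import org.springframework.kafka.listener.AbstractConsumerSeekAware;

class TimestampReplayListener extends AbstractConsumerSeekAware {

	void seekToTimestampForGroup(String groupIdForSeek, long timestamp) {
		getCallbacksAndTopics().forEach((callback, topicPartitions) -> {
			if (groupIdForSeek.equals(callback.getGroupId())) {
				// The Collection variant batches the blocking offset-for-timestamp lookup
				// across all of this group's assigned partitions.
				callback.seekToTimestamp(topicPartitions, timestamp);
			}
		});
	}

}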