Skip to content
Merged
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2023 the original author or authors.
* Copyright 2019-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,7 @@
* having to keep track of the callbacks itself.
*
* @author Gary Russell
* @author Borahm Lee
* @since 2.3
*
*/
Expand All @@ -46,43 +47,59 @@ public abstract class AbstractConsumerSeekAware implements ConsumerSeekAware {

@Override
public void registerSeekCallback(ConsumerSeekCallback callback) {
this.callbackForThread.put(Thread.currentThread(), callback);
if (matchGroupId()) {
this.callbackForThread.put(Thread.currentThread(), callback);
}
}

@Override
public void onPartitionsAssigned(Map<TopicPartition, Long> assignments, ConsumerSeekCallback callback) {
ConsumerSeekCallback threadCallback = this.callbackForThread.get(Thread.currentThread());
if (threadCallback != null) {
assignments.keySet().forEach(tp -> {
this.callbacks.put(tp, threadCallback);
this.callbacksToTopic.computeIfAbsent(threadCallback, key -> new LinkedList<>()).add(tp);
});
if (matchGroupId()) {
ConsumerSeekCallback threadCallback = this.callbackForThread.get(Thread.currentThread());
if (threadCallback != null) {
assignments.keySet()
.forEach(tp -> {
this.callbacks.put(tp, threadCallback);
this.callbacksToTopic.computeIfAbsent(threadCallback, key -> new LinkedList<>())
.add(tp);
});
}
}
}

@Override
public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
partitions.forEach(tp -> {
ConsumerSeekCallback removed = this.callbacks.remove(tp);
if (removed != null) {
List<TopicPartition> topics = this.callbacksToTopic.get(removed);
if (topics != null) {
topics.remove(tp);
if (topics.size() == 0) {
this.callbacksToTopic.remove(removed);
if (matchGroupId()) {
partitions.forEach(tp -> {
ConsumerSeekCallback removed = this.callbacks.remove(tp);
if (removed != null) {
List<TopicPartition> topics = this.callbacksToTopic.get(removed);
if (topics != null) {
topics.remove(tp);
if (topics.size() == 0) {
this.callbacksToTopic.remove(removed);
}
}
}
}
});
});
}
}

@Override
public void unregisterSeekCallback() {
this.callbackForThread.remove(Thread.currentThread());
if (matchGroupId()) {
this.callbackForThread.remove(Thread.currentThread());
}
}

@Override
public boolean matchGroupId() {
return true;
}

/**
* Return the callback for the specified topic/partition.
*
* @param topicPartition the topic/partition.
* @return the callback (or null if there is no assignment).
*/
Expand All @@ -93,6 +110,7 @@ protected ConsumerSeekCallback getSeekCallbackFor(TopicPartition topicPartition)

/**
* The map of callbacks for all currently assigned partitions.
*
* @return the map.
*/
protected Map<TopicPartition, ConsumerSeekCallback> getSeekCallbacks() {
Expand All @@ -101,6 +119,7 @@ protected Map<TopicPartition, ConsumerSeekCallback> getSeekCallbacks() {

/**
* Return the currently registered callbacks and their associated {@link TopicPartition}(s).
*
* @return the map of callbacks and partitions.
* @since 2.6
*/
Expand All @@ -110,6 +129,7 @@ protected Map<ConsumerSeekCallback, List<TopicPartition>> getCallbacksAndTopics(

/**
* Seek all assigned partitions to the beginning.
*
* @since 2.6
*/
public void seekToBeginning() {
Expand All @@ -118,6 +138,7 @@ public void seekToBeginning() {

/**
* Seek all assigned partitions to the end.
*
* @since 2.6
*/
public void seekToEnd() {
Expand All @@ -126,6 +147,7 @@ public void seekToEnd() {

/**
* Seek all assigned partitions to the offset represented by the timestamp.
*
* @param time the time to seek to.
* @since 2.6
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
*
* @author Gary Russell
* @author Soby Chacko
* @author Borahm Lee
* @since 1.1
*
*/
Expand Down Expand Up @@ -88,6 +89,16 @@ default void onFirstPoll() {
default void unregisterSeekCallback() {
}

/**
* Determine if the consumer group ID for seeking matches the expected value.
*
* @return true if the group ID matches, false otherwise.
* @since 3.3
*/
default boolean matchGroupId() {
return false;
}

/**
* A callback that a listener can invoke to seek to a specific offset.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
* @author Raphael Rösch
* @author Christian Mergenthaler
* @author Mikael Carlstedt
* @author Borahm Lee
*/
public class KafkaMessageListenerContainer<K, V> // NOSONAR line count
extends AbstractMessageListenerContainer<K, V> implements ConsumerPauseResumeEventPublisher {
Expand Down Expand Up @@ -1362,8 +1363,8 @@ protected void initialize() {
}
publishConsumerStartingEvent();
this.consumerThread = Thread.currentThread();
setupSeeks();
KafkaUtils.setConsumerGroupId(this.consumerGroupId);
setupSeeks();
this.count = 0;
this.last = System.currentTimeMillis();
initAssignedPartitions();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,27 @@

package org.springframework.kafka.listener;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.kafka.common.TopicPartition;
import org.junit.jupiter.api.Test;

import org.springframework.kafka.listener.ConsumerSeekAware.ConsumerSeekCallback;
import org.springframework.kafka.test.utils.KafkaTestUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

/**
* @author Gary Russell
* @author Borahm Lee
* @since 2.6
*
*/
public class ConsumerSeekAwareTests {

Expand All @@ -51,16 +50,15 @@ class CSA extends AbstractConsumerSeekAware {
var exec1 = Executors.newSingleThreadExecutor();
var exec2 = Executors.newSingleThreadExecutor();
var cb1 = mock(ConsumerSeekCallback.class);
var cb2 = mock(ConsumerSeekCallback.class);
var cb2 = mock(ConsumerSeekCallback.class);
var first = new AtomicBoolean(true);
var map1 = new LinkedHashMap<>(Map.of(new TopicPartition("foo", 0), 0L, new TopicPartition("foo", 1), 0L));
var map2 = new LinkedHashMap<>(Map.of(new TopicPartition("foo", 2), 0L, new TopicPartition("foo", 3), 0L));
var register = (Callable<Void>) () -> {
if (first.getAndSet(false)) {
csa.registerSeekCallback(cb1);
csa.onPartitionsAssigned(map1, null);
}
else {
} else {
csa.registerSeekCallback(cb2);
csa.onPartitionsAssigned(map2, null);
}
Expand All @@ -80,8 +78,7 @@ class CSA extends AbstractConsumerSeekAware {
var revoke1 = (Callable<Void>) () -> {
if (!first.getAndSet(true)) {
csa.onPartitionsRevoked(Collections.singletonList(map1.keySet().iterator().next()));
}
else {
} else {
csa.onPartitionsRevoked(Collections.singletonList(map2.keySet().iterator().next()));
}
return null;
Expand All @@ -96,8 +93,7 @@ class CSA extends AbstractConsumerSeekAware {
var revoke2 = (Callable<Void>) () -> {
if (first.getAndSet(false)) {
csa.onPartitionsRevoked(Collections.singletonList(map1.keySet().iterator().next()));
}
else {
} else {
csa.onPartitionsRevoked(Collections.singletonList(map2.keySet().iterator().next()));
}
return null;
Expand All @@ -118,4 +114,28 @@ class CSA extends AbstractConsumerSeekAware {
exec2.shutdown();
}

@SuppressWarnings("unchecked")
@Test
void notMatchedGroupId() throws ExecutionException, InterruptedException {
class CSA extends AbstractConsumerSeekAware {
@Override
public boolean matchGroupId() {
return false;
}
}

AbstractConsumerSeekAware csa = new CSA();
var exec = Executors.newSingleThreadExecutor();
var register = (Callable<Void>) () -> {
csa.registerSeekCallback(mock(ConsumerSeekCallback.class));
csa.onPartitionsAssigned(Map.of(new TopicPartition("baz", 0), 0L), null);
return null;
};
exec.submit(register).get();
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbackForThread", Map.class)).isEmpty();
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbacks", Map.class)).isEmpty();
assertThat(KafkaTestUtils.getPropertyValue(csa, "callbacksToTopic", Map.class)).isEmpty();
exec.shutdown();
}

}