5757import org .springframework .kafka .core .ConsumerFactory ;
5858import org .springframework .kafka .test .utils .KafkaTestUtils ;
5959import org .springframework .test .annotation .DirtiesContext ;
60+ import org .springframework .test .annotation .DirtiesContext .ClassMode ;
6061import org .springframework .test .context .junit .jupiter .SpringJUnitConfig ;
6162
6263/**
6566 *
6667 */
6768@ SpringJUnitConfig
68- @ DirtiesContext
69+ @ DirtiesContext ( classMode = ClassMode . AFTER_EACH_TEST_METHOD )
6970public class SubBatchPerPartitionTests {
7071
7172 private static final String CONTAINER_ID = "container" ;
@@ -86,7 +87,8 @@ public class SubBatchPerPartitionTests {
8687 */
8788 @ SuppressWarnings ("unchecked" )
8889 @ Test
89- public void discardRemainingRecordsFromPollAndSeek () throws Exception {
90+ void discardRemainingRecordsFromPollAndSeek () throws Exception {
91+ this .registry .getListenerContainer (CONTAINER_ID ).start ();
9092 assertThat (this .config .deliveryLatch .await (10 , TimeUnit .SECONDS )).isTrue ();
9193 assertThat (this .config .commitLatch .await (10 , TimeUnit .SECONDS )).isTrue ();
9294 assertThat (this .config .pollLatch .await (10 , TimeUnit .SECONDS )).isTrue ();
@@ -98,6 +100,25 @@ public void discardRemainingRecordsFromPollAndSeek() throws Exception {
98100 inOrder .verify (this .consumer , times (3 )).commitSync (any (), eq (Duration .ofSeconds (60 )));
99101 inOrder .verify (this .consumer ).poll (Duration .ofMillis (ContainerProperties .DEFAULT_POLL_TIMEOUT ));
100102 assertThat (this .config .contents ).contains ("foo" , "bar" , "baz" , "qux" , "fiz" , "buz" );
103+ this .registry .stop ();
104+ }
105+
106+ @ SuppressWarnings ("unchecked" )
107+ @ Test
108+ void withFilter () throws Exception {
109+ this .registry .getListenerContainer (CONTAINER_ID + ".filtered" ).start ();
110+ assertThat (this .config .deliveryLatch .await (10 , TimeUnit .SECONDS )).isTrue ();
111+ assertThat (this .config .commitLatch .await (10 , TimeUnit .SECONDS )).isTrue ();
112+ assertThat (this .config .pollLatch .await (10 , TimeUnit .SECONDS )).isTrue ();
113+ this .registry .stop ();
114+ assertThat (this .config .closeLatch .await (10 , TimeUnit .SECONDS )).isTrue ();
115+ InOrder inOrder = inOrder (this .consumer );
116+ inOrder .verify (this .consumer ).subscribe (any (Collection .class ), any (ConsumerRebalanceListener .class ));
117+ inOrder .verify (this .consumer ).poll (Duration .ofMillis (ContainerProperties .DEFAULT_POLL_TIMEOUT ));
118+ inOrder .verify (this .consumer , times (3 )).commitSync (any (), eq (Duration .ofSeconds (60 )));
119+ inOrder .verify (this .consumer ).poll (Duration .ofMillis (ContainerProperties .DEFAULT_POLL_TIMEOUT ));
120+ assertThat (this .config .filtered ).contains ("bar" , "qux" , "buz" );
121+ this .registry .stop ();
101122 }
102123
103124 @ Configuration
@@ -106,6 +127,8 @@ public static class Config {
106127
107128 private final List <String > contents = new ArrayList <>();
108129
130+ private final List <String > filtered = new ArrayList <>();
131+
109132 private final CountDownLatch pollLatch = new CountDownLatch (2 );
110133
111134 private final CountDownLatch deliveryLatch = new CountDownLatch (3 );
@@ -114,19 +137,27 @@ public static class Config {
114137
115138 private final CountDownLatch closeLatch = new CountDownLatch (1 );
116139
117- @ KafkaListener (id = CONTAINER_ID , topics = "foo" )
140+ @ KafkaListener (id = CONTAINER_ID , topics = "foo" , autoStartup = "false" )
118141 public void foo (List <String > in ) {
119142 contents .addAll (in );
120143 this .deliveryLatch .countDown ();
121144 }
122145
146+ @ KafkaListener (id = CONTAINER_ID + ".filtered" , topics = "foo" , autoStartup = "false" ,
147+ containerFactory = "filteredFactory" )
148+ public void filtered (List <String > in ) {
149+ filtered .addAll (in );
150+ this .deliveryLatch .countDown ();
151+ }
152+
123153 @ SuppressWarnings ({ "rawtypes" })
124154 @ Bean
125155 public ConsumerFactory consumerFactory () {
126156 ConsumerFactory consumerFactory = mock (ConsumerFactory .class );
127157 final Consumer consumer = consumer ();
128- given (consumerFactory .createConsumer (CONTAINER_ID , "" , "-0" , KafkaTestUtils .defaultPropertyOverrides ()))
129- .willReturn (consumer );
158+ given (consumerFactory .createConsumer (any (), eq ("" ), eq ("-0" ),
159+ eq (KafkaTestUtils .defaultPropertyOverrides ())))
160+ .willReturn (consumer );
130161 return consumerFactory ;
131162 }
132163
@@ -152,10 +183,13 @@ public Consumer consumer() {
152183 records1 .put (topicPartition2 , Arrays .asList (
153184 new ConsumerRecord ("foo" , 2 , 0L , 0L , TimestampType .NO_TIMESTAMP_TYPE , 0 , 0 , 0 , null , "fiz" ),
154185 new ConsumerRecord ("foo" , 2 , 1L , 0L , TimestampType .NO_TIMESTAMP_TYPE , 0 , 0 , 0 , null , "buz" )));
155- final AtomicInteger which = new AtomicInteger ();
186+ final ThreadLocal < AtomicInteger > which = new ThreadLocal <> ();
156187 willAnswer (i -> {
157188 this .pollLatch .countDown ();
158- switch (which .getAndIncrement ()) {
189+ if (which .get () == null ) {
190+ which .set (new AtomicInteger ());
191+ }
192+ switch (which .get ().getAndIncrement ()) {
159193 case 0 :
160194 return new ConsumerRecords (records1 );
161195 default :
@@ -191,6 +225,19 @@ public ConcurrentKafkaListenerContainerFactory kafkaListenerContainerFactory() {
191225 return factory ;
192226 }
193227
228+ @ SuppressWarnings ({ "rawtypes" , "unchecked" })
229+ @ Bean
230+ public ConcurrentKafkaListenerContainerFactory filteredFactory () {
231+ ConcurrentKafkaListenerContainerFactory factory = new ConcurrentKafkaListenerContainerFactory ();
232+ factory .setConsumerFactory (consumerFactory ());
233+ factory .getContainerProperties ().setAckOnError (false );
234+ factory .setBatchListener (true );
235+ factory .getContainerProperties ().setMissingTopicsFatal (false );
236+ factory .getContainerProperties ().setSubBatchPerPartition (true );
237+ factory .setRecordFilterStrategy (rec -> rec .offset () == 0 );
238+ return factory ;
239+ }
240+
194241 }
195242
196243}
0 commit comments