From 9b8f42d074eb2d09ec20862a4d923d854fcb2028 Mon Sep 17 00:00:00 2001 From: Apostolos Rousas Date: Mon, 22 Sep 2025 14:04:32 +0300 Subject: [PATCH] Set trackTotalHits in searchMsearchRequest Signed-off-by: Apostolos Rousas --- .../data/client/osc/RequestConverter.java | 6 ++ .../data/client/osc/RequestConverterTest.java | 74 +++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/spring-data-opensearch/src/main/java/org/opensearch/data/client/osc/RequestConverter.java b/spring-data-opensearch/src/main/java/org/opensearch/data/client/osc/RequestConverter.java index 80e87665..8adb10f9 100644 --- a/spring-data-opensearch/src/main/java/org/opensearch/data/client/osc/RequestConverter.java +++ b/spring-data-opensearch/src/main/java/org/opensearch/data/client/osc/RequestConverter.java @@ -1394,6 +1394,12 @@ public MsearchRequest searchMsearchRequest( query.getScriptedFields().forEach(scriptedField -> bb.scriptFields(scriptedField.getFieldName(), sf -> sf.script(getScript(scriptedField.getScriptData())))); + if (query.getTrackTotalHits() != null) { + bb.trackTotalHits(th -> th.enabled(query.getTrackTotalHits())); + } else if (query.getTrackTotalHitsUpTo() != null) { + bb.trackTotalHits(th -> th.count(query.getTrackTotalHitsUpTo())); + } + if (query instanceof NativeQuery nativeQuery) { prepareNativeSearch(nativeQuery, bb); } diff --git a/spring-data-opensearch/src/test/java/org/opensearch/data/client/osc/RequestConverterTest.java b/spring-data-opensearch/src/test/java/org/opensearch/data/client/osc/RequestConverterTest.java index 6ad97d44..9a9ef615 100644 --- a/spring-data-opensearch/src/test/java/org/opensearch/data/client/osc/RequestConverterTest.java +++ b/spring-data-opensearch/src/test/java/org/opensearch/data/client/osc/RequestConverterTest.java @@ -17,11 +17,14 @@ import static org.assertj.core.api.Assertions.*; +import java.util.ArrayList; import java.util.List; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.opensearch.client.json.jackson.JacksonJsonpMapper; import org.opensearch.client.opensearch._types.Refresh; +import org.opensearch.client.opensearch.core.msearch.RequestItem; +import org.opensearch.client.opensearch.core.search.TrackHits; import org.springframework.data.annotation.Id; import org.springframework.data.elasticsearch.annotations.Document; import org.springframework.data.elasticsearch.annotations.Field; @@ -86,6 +89,77 @@ void refreshSetByDeleteRequest() { assertThat(deleteByQueryRequest.refresh()).isEqualTo(Refresh.True); } + + @Test // #542 + @DisplayName("should set track_total_hits to true on searchMsearchRequest") + void shouldSetTrackTotalTrueOnMultiSearch() { + var query = new NativeQueryBuilder() + .withQuery(Queries.matchAllQuery().toQuery()) + .withTrackTotalHits(true) + .build(); + + var multiSearchQueryParameters = new ArrayList(); + multiSearchQueryParameters.add(new OpenSearchTemplate.MultiSearchQueryParameter(query, SampleEntity.class, IndexCoordinates.of("foo"))); + + var searchRequest = requestConverter.searchMsearchRequest(multiSearchQueryParameters, null); + + List searches = searchRequest.searches(); + assertThat(searches).hasSize(1); + + TrackHits trackTotalHits = searches.getFirst().body().trackTotalHits(); + assertThat(trackTotalHits).isNotNull(); + assertThat(trackTotalHits.isCount()).isFalse(); + assertThat(trackTotalHits.isEnabled()).isTrue(); + assertThat(trackTotalHits.enabled()).isTrue(); + } + + @Test // #542 + @DisplayName("should set track_total_hits to false on searchMsearchRequest") + void shouldSetTrackTotalFalseOnMultiSearch() { + var query = new NativeQueryBuilder() + .withQuery(Queries.matchAllQuery().toQuery()) + .withTrackTotalHits(false) + .build(); + + var multiSearchQueryParameters = new ArrayList(); + multiSearchQueryParameters.add(new OpenSearchTemplate.MultiSearchQueryParameter(query, SampleEntity.class, IndexCoordinates.of("foo"))); + + var searchRequest = requestConverter.searchMsearchRequest(multiSearchQueryParameters, null); + + List searches = searchRequest.searches(); + assertThat(searches).hasSize(1); + + TrackHits trackTotalHits = searches.getFirst().body().trackTotalHits(); + assertThat(trackTotalHits).isNotNull(); + assertThat(trackTotalHits.isCount()).isFalse(); + assertThat(trackTotalHits.isEnabled()).isTrue(); + assertThat(trackTotalHits.enabled()).isFalse(); + } + + @Test // #542 + @DisplayName("should set track_total_hits to count value on searchMsearchRequest") + void shouldSetTrackTotalCountValueOnMultiSearch() { + int countValue = 5000; + var query = new NativeQueryBuilder() + .withQuery(Queries.matchAllQuery().toQuery()) + .withTrackTotalHitsUpTo(countValue) + .build(); + + var multiSearchQueryParameters = new ArrayList(); + multiSearchQueryParameters.add(new OpenSearchTemplate.MultiSearchQueryParameter(query, SampleEntity.class, IndexCoordinates.of("foo"))); + + var searchRequest = requestConverter.searchMsearchRequest(multiSearchQueryParameters, null); + + List searches = searchRequest.searches(); + assertThat(searches).hasSize(1); + + TrackHits trackTotalHits = searches.getFirst().body().trackTotalHits(); + assertThat(trackTotalHits).isNotNull(); + assertThat(trackTotalHits.isEnabled()).isFalse(); + assertThat(trackTotalHits.isCount()).isTrue(); + assertThat(trackTotalHits.count()).isEqualTo(countValue); + } + @Document(indexName = "does-not-matter") static class SampleEntity { @Nullable