Skip to content

Commit 96e2b08

Browse files
committed
[#83] Add array_contains function implementation
1 parent 7fbbf67 commit 96e2b08

File tree

8 files changed

+244
-21
lines changed

8 files changed

+244
-21
lines changed

core/impl/src/main/java/com/blazebit/query/impl/calcite/CalciteDataSource.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
*/
55
package com.blazebit.query.impl.calcite;
66

7+
import com.blazebit.query.impl.calcite.function.ArrayContainsFunction;
78
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
89
import org.apache.calcite.adapter.enumerable.EnumerableRel;
910
import org.apache.calcite.adapter.java.JavaTypeFactory;
@@ -16,6 +17,7 @@
1617
import org.apache.calcite.jdbc.CalcitePrepare;
1718
import org.apache.calcite.jdbc.CalciteSchema;
1819
import org.apache.calcite.jdbc.Driver;
20+
import org.apache.calcite.linq4j.Queryable;
1921
import org.apache.calcite.plan.Convention;
2022
import org.apache.calcite.plan.RelOptCluster;
2123
import org.apache.calcite.plan.RelOptPlanner;
@@ -73,6 +75,8 @@ public boolean shouldConvertRaggedUnionTypesToVarying() {
7375
}
7476
this.typeFactory = new CustomJavaTypeFactory( typeSystem );
7577
this.rootSchema = CalciteSchema.createRootSchema( true );
78+
// Custom functions for Blaze-Query
79+
this.rootSchema.plus().add( "array_contains", new ArrayContainsFunction());
7680
}
7781

7882
public SchemaPlus getRootSchema() {
@@ -160,6 +164,19 @@ public <T> T unwrap(Class<T> iface) throws SQLException {
160164
}
161165

162166
private static class MyCalcitePrepareImpl extends CalcitePrepareImpl {
167+
168+
@Override public <T> CalciteSignature<T> prepareQueryable(Context context, Queryable<T> queryable) {
169+
return super.prepareQueryable(wrapContext(context), queryable);
170+
}
171+
172+
@Override public <T> CalciteSignature<T> prepareSql(Context context, Query<T> query, Type elementType, long maxRowCount) {
173+
return super.prepareSql(wrapContext(context), query, elementType, maxRowCount);
174+
}
175+
176+
private Context wrapContext(Context context) {
177+
return context;
178+
}
179+
163180
@Override
164181
protected CalcitePreparingStmt getPreparingStmt(Context context, Type elementType, CalciteCatalogReader catalogReader, RelOptPlanner planner) {
165182
final JavaTypeFactory typeFactory = context.getTypeFactory();
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Blazebit
4+
*/
5+
package com.blazebit.query.impl.calcite.function;
6+
7+
import com.google.common.collect.ImmutableList;
8+
import org.apache.calcite.adapter.enumerable.CallImplementor;
9+
import org.apache.calcite.adapter.enumerable.RexImpTable;
10+
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
11+
import org.apache.calcite.linq4j.tree.Expression;
12+
import org.apache.calcite.linq4j.tree.Expressions;
13+
import org.apache.calcite.rel.type.RelDataType;
14+
import org.apache.calcite.rel.type.RelDataTypeFactory;
15+
import org.apache.calcite.rex.RexCall;
16+
import org.apache.calcite.runtime.FlatLists;
17+
import org.apache.calcite.schema.FunctionParameter;
18+
import org.apache.calcite.schema.ImplementableFunction;
19+
import org.apache.calcite.schema.ScalarFunction;
20+
import org.apache.calcite.sql.SqlCollation;
21+
import org.apache.calcite.sql.SqlUtil;
22+
import org.apache.calcite.sql.type.SqlTypeUtil;
23+
24+
import java.lang.reflect.Method;
25+
import java.math.BigDecimal;
26+
import java.nio.charset.Charset;
27+
import java.util.Comparator;
28+
import java.util.List;
29+
30+
import static org.apache.calcite.adapter.enumerable.EnumUtils.generateCollatorExpression;
31+
import static org.apache.calcite.linq4j.Nullness.castNonNull;
32+
import static org.apache.calcite.util.Static.RESOURCE;
33+
34+
/**
35+
* @author Christian Beikov
36+
* @since 1.0.0
37+
*/
38+
public class ArrayContainsFunction implements ScalarFunction, ImplementableFunction, CallImplementor {
39+
40+
private final Method basicMethod;
41+
private final Method decimalMethod;
42+
private final Method comparatorMethod;
43+
private final List<FunctionParameter> parameters = ImmutableList.of(
44+
new SimpleFunctionParameter(0, "array", Object.class, false),
45+
new SimpleFunctionParameter(1, "element", Object.class, false)
46+
);
47+
48+
public ArrayContainsFunction() {
49+
try {
50+
this.basicMethod = ArrayContainsFunction.class.getDeclaredMethod("arrayContains", List.class, Object.class);
51+
this.decimalMethod = ArrayContainsFunction.class.getDeclaredMethod("arrayContains", List.class, BigDecimal.class);
52+
this.comparatorMethod = ArrayContainsFunction.class.getDeclaredMethod("arrayContains", List.class, String.class, Comparator.class);
53+
}
54+
catch (NoSuchMethodException e) {
55+
throw new RuntimeException( e );
56+
}
57+
}
58+
59+
@Override
60+
public List<FunctionParameter> getParameters() {
61+
return parameters;
62+
}
63+
64+
@Override
65+
public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
66+
return typeFactory.createJavaType(Boolean.class);
67+
}
68+
69+
@Override
70+
public CallImplementor getImplementor() {
71+
return this;
72+
}
73+
74+
@Override
75+
public Expression implement(RexToLixTranslator translator, RexCall call, RexImpTable.NullAs nullAs) {
76+
final List<Expression> expressions = translator.translateList(call.getOperands());
77+
final RelDataType operandType0 = castNonNull(call.getOperands().get(0).getType().getComponentType());
78+
final RelDataType operandType1 = call.getOperands().get(1).getType();
79+
final Expression fieldComparator;
80+
if (SqlTypeUtil.inCharFamily(operandType0) && SqlTypeUtil.inCharFamily(operandType1)) {
81+
Charset cs0 = operandType0.getCharset();
82+
Charset cs1 = operandType1.getCharset();
83+
assert (null != cs0) && (null != cs1)
84+
: "An implicit or explicit charset should have been set";
85+
if (!cs0.equals(cs1)) {
86+
throw SqlUtil.newContextException(call.pos, RESOURCE.incompatibleCharset("array_contains", cs0.name(), cs1.name()));
87+
}
88+
89+
SqlCollation collation0 = operandType0.getCollation();
90+
SqlCollation collation1 = operandType1.getCollation();
91+
assert (null != collation0) && (null != collation1)
92+
: "An implicit or explicit collation should have been set";
93+
94+
// Validation will occur inside getCoercibilityDyadicOperator...
95+
SqlCollation resultCol = SqlCollation.getCoercibilityDyadicOperator(collation0, collation1);
96+
fieldComparator = generateCollatorExpression(resultCol);
97+
}
98+
else {
99+
fieldComparator = null;
100+
}
101+
if (expressions.get(1).getType() == BigDecimal.class) {
102+
return Expressions.call(decimalMethod, expressions);
103+
}
104+
105+
return fieldComparator == null
106+
? Expressions.call(basicMethod, expressions)
107+
: Expressions.call(comparatorMethod, FlatLists.append(expressions, fieldComparator));
108+
}
109+
110+
public static Boolean arrayContains(List<?> list, Object element) {
111+
return list == null ? null : list.contains(element);
112+
}
113+
114+
public static Boolean arrayContains(List<BigDecimal> list, BigDecimal element) {
115+
if (list == null) {
116+
return null;
117+
}
118+
for (BigDecimal t : list) {
119+
if (t.compareTo(element) == 0) {
120+
return true;
121+
}
122+
}
123+
return false;
124+
}
125+
126+
public static Boolean arrayContains(List<String> list, String element, Comparator<String> comparator) {
127+
if (list == null) {
128+
return null;
129+
}
130+
for (String t : list) {
131+
if (comparator.compare(t, element) == 0) {
132+
return true;
133+
}
134+
}
135+
return false;
136+
}
137+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Blazebit
4+
*/
5+
package com.blazebit.query.impl.calcite.function;
6+
7+
import org.apache.calcite.rel.type.RelDataType;
8+
import org.apache.calcite.rel.type.RelDataTypeFactory;
9+
import org.apache.calcite.schema.FunctionParameter;
10+
11+
/**
12+
* @author Christian Beikov
13+
* @since 1.0.0
14+
*/
15+
public class SimpleFunctionParameter implements FunctionParameter {
16+
17+
private final int ordinal;
18+
private final String name;
19+
private final Class<?> type;
20+
private final boolean optional;
21+
22+
public SimpleFunctionParameter(int ordinal, String name, Class<?> type, boolean optional) {
23+
this.ordinal = ordinal;
24+
this.name = name;
25+
this.type = type;
26+
this.optional = optional;
27+
}
28+
29+
@Override
30+
public int getOrdinal() {
31+
return ordinal;
32+
}
33+
34+
@Override
35+
public String getName() {
36+
return name;
37+
}
38+
39+
@Override
40+
public RelDataType getType(RelDataTypeFactory typeFactory) {
41+
return typeFactory.createJavaType(type);
42+
}
43+
44+
@Override
45+
public boolean isOptional() {
46+
return optional;
47+
}
48+
}

core/impl/src/test/java/com/blazebit/query/impl/EnumArrayTest.java

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ public class EnumArrayTest {
3131
public void test() {
3232
JiraCloudAdminUser model = new JiraCloudAdminUser(
3333
"u1",
34-
List.of( PlatformRole.ADMIN )
34+
List.of( PlatformRole.ADMIN ),
35+
36+
List.of("u1", "admin")
3537
);
3638
QueryContextBuilder queryContextBuilder = Queries.createQueryContextBuilder();
3739
queryContextBuilder.registerSchemaObject( JiraCloudAdminUser.class, new DataFetcher<>() {
@@ -49,19 +51,24 @@ public List<JiraCloudAdminUser> fetch(DataFetchContext context) {
4951
try (QueryContext queryContext = queryContextBuilder.build()) {
5052
try (QuerySession session = queryContext.createSession()) {
5153
TypedQuery<Object[]> query = session.createQuery(
52-
"select r = 'ADMIN' " +
54+
"select r = 'ADMIN', array_contains(u.platformRoles, 'ADMIN'), array_contains(u.aliases, 'ADMIN') " +
5355
"from JiraCloudAdminUser u " +
5456
"cross join unnest(u.platformRoles) r"
5557
);
5658
List<Object[]> result = query.getResultList();
5759
assertEquals( 1, result.size() );
60+
assertEquals( 3, result.get( 0 ).length );
61+
assertEquals( true, result.get( 0 )[0] );
62+
assertEquals( true, result.get( 0 )[1] );
63+
assertEquals( false, result.get( 0 )[2] );
5864
}
5965
}
6066
}
6167

6268
public record JiraCloudAdminUser(
6369
String id,
64-
List<PlatformRole> platformRoles) {
70+
List<PlatformRole> platformRoles,
71+
List<String> aliases) {
6572
}
6673

6774
enum PlatformRole {

examples/app/src/main/java/com/blazebit/query/app/Main.java

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,13 @@
5858
import com.blazebit.query.connector.azure.graph.AzureGraphApplication;
5959
import com.blazebit.query.connector.azure.graph.AzureGraphClientAccessor;
6060
import com.blazebit.query.connector.azure.graph.AzureGraphConditionalAccessPolicy;
61-
import com.blazebit.query.connector.azure.graph.AzureGraphConnectorConfig;
6261
import com.blazebit.query.connector.azure.graph.AzureGraphIncident;
6362
import com.blazebit.query.connector.azure.graph.AzureGraphManagedDevice;
6463
import com.blazebit.query.connector.azure.graph.AzureGraphOrganization;
6564
import com.blazebit.query.connector.azure.graph.AzureGraphServicePlanInfo;
6665
import com.blazebit.query.connector.azure.graph.AzureGraphUser;
6766
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceBlobServiceProperties;
6867
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagedCluster;
69-
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagerConnectorConfig;
70-
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagerPostgreSqlManagerConnectorConfig;
7168
import com.blazebit.query.connector.azure.resourcemanager.AzureResourcePostgreSqlFlexibleServer;
7269
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagerPostgreSqlManager;
7370
import com.blazebit.query.connector.azure.resourcemanager.AzureResourcePostgreSqlFlexibleServerBackup;
@@ -78,7 +75,6 @@
7875
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceVirtualMachine;
7976
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceVirtualNetwork;
8077
import com.blazebit.query.connector.github.graphql.GitHubBranchProtectionRule;
81-
import com.blazebit.query.connector.github.graphql.GitHubConnectorConfig;
8278
import com.blazebit.query.connector.github.graphql.GitHubGraphQlClient;
8379
import com.blazebit.query.connector.github.graphql.GitHubOrganization;
8480
import com.blazebit.query.connector.github.graphql.GitHubPullRequest;
@@ -87,9 +83,7 @@
8783
import com.blazebit.query.connector.github.v0314.model.OrganizationSimple;
8884
import com.blazebit.query.connector.github.v0314.model.ShortBranch;
8985
import com.blazebit.query.connector.github.v0314.model.Team;
90-
import com.blazebit.query.connector.gitlab.GitlabConnectorConfig;
9186
import com.blazebit.query.connector.gitlab.GitlabGraphQlClient;
92-
import com.blazebit.query.connector.gitlab.GitlabGraphQlConnectorConfig;
9387
import com.blazebit.query.connector.gitlab.GitlabGroup;
9488
import com.blazebit.query.connector.gitlab.GitlabMergeRequest;
9589
import com.blazebit.query.connector.gitlab.GitlabProject;
@@ -192,7 +186,7 @@ public static void main(String[] args) throws Exception {
192186
try (EntityManagerFactory emf = Persistence.createEntityManagerFactory( "default" )) {
193187
SessionFactory sf = emf.unwrap( SessionFactory.class );
194188
sf.inTransaction( s -> {
195-
s.persist( new TestEntity( 1L, "Test", new TestEmbeddable( "text1", "text2" ) ) );
189+
s.persist( new TestEntity( 1L, "Test", new TestEmbeddable( "text1", "text2" ), Set.of(TestEnum.A, TestEnum.B) ) );
196190
} );
197191

198192
CriteriaBuilderFactory cbf = Criteria.getDefault().createCriteriaBuilderFactory( emf );
@@ -202,11 +196,11 @@ public static void main(String[] args) throws Exception {
202196
EntityViewManager evm = defaultConfiguration.createEntityViewManager( cbf );
203197

204198
QueryContextBuilder queryContextBuilder = Queries.createQueryContextBuilder();
205-
queryContextBuilder.setProperty( AzureResourceManagerConnectorConfig.AZURE_RESOURCE_MANAGER.getPropertyName(), createResourceManager());
206-
queryContextBuilder.setPropertyProvider( AzureResourceManagerPostgreSqlManagerConnectorConfig.POSTGRESQL_MANAGER.getPropertyName(),
207-
Main::createPostgreSqlManagers );
208-
queryContextBuilder.setProperty( "serverParameters", List.of("ssl_min_protocol_version", "authentication_timeout"));
209-
queryContextBuilder.setProperty( AzureGraphConnectorConfig.GRAPH_SERVICE_CLIENT.getPropertyName(), createGraphServiceClient());
199+
// queryContextBuilder.setProperty( AzureResourceManagerConnectorConfig.AZURE_RESOURCE_MANAGER.getPropertyName(), createResourceManager());
200+
// queryContextBuilder.setPropertyProvider( AzureResourceManagerPostgreSqlManagerConnectorConfig.POSTGRESQL_MANAGER.getPropertyName(),
201+
// Main::createPostgreSqlManagers );
202+
// queryContextBuilder.setProperty( "serverParameters", List.of("ssl_min_protocol_version", "authentication_timeout"));
203+
// queryContextBuilder.setProperty( AzureGraphConnectorConfig.GRAPH_SERVICE_CLIENT.getPropertyName(), createGraphServiceClient());
210204
// queryContextBuilder.setProperty( AwsConnectorConfig.ACCOUNT.getPropertyName(), createAwsAccount() );
211205
// queryContextBuilder.setProperty( GoogleDirectoryConnectorConfig.GOOGLE_DIRECTORY_SERVICE.getPropertyName(), createGoogleDirectory() );
212206
// queryContextBuilder.setProperty( GoogleDriveConnectorConfig.GOOGLE_DRIVE_SERVICE.getPropertyName(), createGoogleDrive() );
@@ -216,11 +210,11 @@ public static void main(String[] args) throws Exception {
216210
// queryContextBuilder.setProperty( "jqlQuery", "statusCategory != Done");
217211
// queryContextBuilder.setProperty( JiraCloudAdminConnectorConfig.API_CLIENT.getPropertyName(), createJiraCloudAdminOrganizationApiClient());
218212
queryContextBuilder.setProperty( EntityViewConnectorConfig.ENTITY_VIEW_MANAGER.getPropertyName(), evm );
219-
queryContextBuilder.setProperty( GitlabConnectorConfig.GITLAB_API.getPropertyName(), createGitlabApi());
220-
queryContextBuilder.setProperty( GitlabGraphQlConnectorConfig.GITLAB_GRAPHQL_CLIENT.getPropertyName(), createGitlabGraphQLClient());
213+
// queryContextBuilder.setProperty( GitlabConnectorConfig.GITLAB_API.getPropertyName(), createGitlabApi());
214+
// queryContextBuilder.setProperty( GitlabGraphQlConnectorConfig.GITLAB_GRAPHQL_CLIENT.getPropertyName(), createGitlabGraphQLClient());
221215
// queryContextBuilder.setProperty(KandjiConnectorConfig.API_CLIENT.getPropertyName(), createKandjiApiClient());
222216
// queryContextBuilder.setProperty(GithubConnectorConfig.GITHUB.getPropertyName(), createGithub());
223-
queryContextBuilder.setProperty( GitHubConnectorConfig.GITHUB_GRAPHQL_CLIENT.getPropertyName(), createGitHubGraphQLClient());
217+
// queryContextBuilder.setProperty( GitHubConnectorConfig.GITHUB_GRAPHQL_CLIENT.getPropertyName(), createGitHubGraphQLClient());
224218
// queryContextBuilder.setProperty(com.blazebit.query.connector.github.v0314.GithubConnectorConfig.API_CLIENT.getPropertyName(), createGitHubApiClient());
225219

226220
// Azure Resource manager
@@ -385,8 +379,8 @@ public static void main(String[] args) throws Exception {
385379
// testGitHub( session );
386380
// testGitHubOpenAPI( session );
387381
// testKandji( session );
388-
// testEntityView( session );
389-
testAzureGraph( session );
382+
testEntityView( session );
383+
// testAzureGraph( session );
390384
// testAzureResourceManager( session );
391385
}
392386
}
@@ -931,6 +925,10 @@ private static void testEntityView(QuerySession session) {
931925
"select t.id, e.text1 from " + name( TestEntityView.class ) + " t, unnest(t.elements) e" );
932926
List<Object[]> entityViewResult = entityViewQuery.getResultList();
933927
print( entityViewResult, "id", "text1" );
928+
TypedQuery<Object[]> entityViewQuery2 = session.createQuery(
929+
"select t.id, array_contains(t.enums, 'A') from " + name( TestEntityView.class ) + " t" );
930+
List<Object[]> entityViewResult2 = entityViewQuery2.getResultList();
931+
print( entityViewResult2, "id", "enums" );
934932
}
935933

936934
private static void testAzureGraph(QuerySession session) {

examples/app/src/main/java/com/blazebit/query/app/TestEntity.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,18 @@ public class TestEntity {
2323
TestEmbeddable embedded;
2424
@ElementCollection
2525
Set<TestEmbeddable> elements;
26+
@ElementCollection
27+
Set<TestEnum> enums;
2628

2729
public TestEntity() {
2830
}
2931

30-
public TestEntity(Long id, String name, TestEmbeddable embedded) {
32+
public TestEntity(Long id, String name, TestEmbeddable embedded, Set<TestEnum> enums) {
3133
this.id = id;
3234
this.name = name;
3335
this.embedded = embedded;
3436
this.elements = new HashSet<>( Collections.singletonList( embedded ) );
37+
this.enums = new HashSet<>( enums );
3538
}
3639

3740
@Override

examples/app/src/main/java/com/blazebit/query/app/TestEntityView.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,6 @@ public interface TestEntityView {
1919
TestEmbeddableView getEmbedded();
2020

2121
Set<TestEmbeddableView> getElements();
22+
23+
Set<TestEnum> getEnums();
2224
}

0 commit comments

Comments
 (0)