diff --git a/core/impl/src/main/java/com/blazebit/query/impl/calcite/CalciteDataSource.java b/core/impl/src/main/java/com/blazebit/query/impl/calcite/CalciteDataSource.java index 95ed1982..bebb2165 100644 --- a/core/impl/src/main/java/com/blazebit/query/impl/calcite/CalciteDataSource.java +++ b/core/impl/src/main/java/com/blazebit/query/impl/calcite/CalciteDataSource.java @@ -4,6 +4,7 @@ */ package com.blazebit.query.impl.calcite; +import com.blazebit.query.impl.calcite.function.ArrayContainsFunction; import org.apache.calcite.adapter.enumerable.EnumerableConvention; import org.apache.calcite.adapter.enumerable.EnumerableRel; import org.apache.calcite.adapter.java.JavaTypeFactory; @@ -16,6 +17,7 @@ import org.apache.calcite.jdbc.CalcitePrepare; import org.apache.calcite.jdbc.CalciteSchema; import org.apache.calcite.jdbc.Driver; +import org.apache.calcite.linq4j.Queryable; import org.apache.calcite.plan.Convention; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptPlanner; @@ -73,6 +75,8 @@ public boolean shouldConvertRaggedUnionTypesToVarying() { } this.typeFactory = new CustomJavaTypeFactory( typeSystem ); this.rootSchema = CalciteSchema.createRootSchema( true ); + // Custom functions for Blaze-Query + this.rootSchema.plus().add( "array_contains", new ArrayContainsFunction()); } public SchemaPlus getRootSchema() { @@ -160,6 +164,19 @@ public T unwrap(Class iface) throws SQLException { } private static class MyCalcitePrepareImpl extends CalcitePrepareImpl { + + @Override public CalciteSignature prepareQueryable(Context context, Queryable queryable) { + return super.prepareQueryable(wrapContext(context), queryable); + } + + @Override public CalciteSignature prepareSql(Context context, Query query, Type elementType, long maxRowCount) { + return super.prepareSql(wrapContext(context), query, elementType, maxRowCount); + } + + private Context wrapContext(Context context) { + return context; + } + @Override protected CalcitePreparingStmt getPreparingStmt(Context context, Type elementType, CalciteCatalogReader catalogReader, RelOptPlanner planner) { final JavaTypeFactory typeFactory = context.getTypeFactory(); diff --git a/core/impl/src/main/java/com/blazebit/query/impl/calcite/function/ArrayContainsFunction.java b/core/impl/src/main/java/com/blazebit/query/impl/calcite/function/ArrayContainsFunction.java new file mode 100644 index 00000000..a4664696 --- /dev/null +++ b/core/impl/src/main/java/com/blazebit/query/impl/calcite/function/ArrayContainsFunction.java @@ -0,0 +1,137 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Blazebit + */ +package com.blazebit.query.impl.calcite.function; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.adapter.enumerable.CallImplementor; +import org.apache.calcite.adapter.enumerable.RexImpTable; +import org.apache.calcite.adapter.enumerable.RexToLixTranslator; +import org.apache.calcite.linq4j.tree.Expression; +import org.apache.calcite.linq4j.tree.Expressions; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.runtime.FlatLists; +import org.apache.calcite.schema.FunctionParameter; +import org.apache.calcite.schema.ImplementableFunction; +import org.apache.calcite.schema.ScalarFunction; +import org.apache.calcite.sql.SqlCollation; +import org.apache.calcite.sql.SqlUtil; +import org.apache.calcite.sql.type.SqlTypeUtil; + +import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.nio.charset.Charset; +import java.util.Comparator; +import java.util.List; + +import static org.apache.calcite.adapter.enumerable.EnumUtils.generateCollatorExpression; +import static org.apache.calcite.linq4j.Nullness.castNonNull; +import static org.apache.calcite.util.Static.RESOURCE; + +/** + * @author Christian Beikov + * @since 1.0.0 + */ +public class ArrayContainsFunction implements ScalarFunction, ImplementableFunction, CallImplementor { + + private final Method basicMethod; + private final Method decimalMethod; + private final Method comparatorMethod; + private final List parameters = ImmutableList.of( + new SimpleFunctionParameter(0, "array", Object.class, false), + new SimpleFunctionParameter(1, "element", Object.class, false) + ); + + public ArrayContainsFunction() { + try { + this.basicMethod = ArrayContainsFunction.class.getDeclaredMethod("arrayContains", List.class, Object.class); + this.decimalMethod = ArrayContainsFunction.class.getDeclaredMethod("arrayContains", List.class, BigDecimal.class); + this.comparatorMethod = ArrayContainsFunction.class.getDeclaredMethod("arrayContains", List.class, String.class, Comparator.class); + } + catch (NoSuchMethodException e) { + throw new RuntimeException( e ); + } + } + + @Override + public List getParameters() { + return parameters; + } + + @Override + public RelDataType getReturnType(RelDataTypeFactory typeFactory) { + return typeFactory.createJavaType(Boolean.class); + } + + @Override + public CallImplementor getImplementor() { + return this; + } + + @Override + public Expression implement(RexToLixTranslator translator, RexCall call, RexImpTable.NullAs nullAs) { + final List expressions = translator.translateList(call.getOperands()); + final RelDataType operandType0 = castNonNull(call.getOperands().get(0).getType().getComponentType()); + final RelDataType operandType1 = call.getOperands().get(1).getType(); + final Expression fieldComparator; + if (SqlTypeUtil.inCharFamily(operandType0) && SqlTypeUtil.inCharFamily(operandType1)) { + Charset cs0 = operandType0.getCharset(); + Charset cs1 = operandType1.getCharset(); + assert (null != cs0) && (null != cs1) + : "An implicit or explicit charset should have been set"; + if (!cs0.equals(cs1)) { + throw SqlUtil.newContextException(call.pos, RESOURCE.incompatibleCharset("array_contains", cs0.name(), cs1.name())); + } + + SqlCollation collation0 = operandType0.getCollation(); + SqlCollation collation1 = operandType1.getCollation(); + assert (null != collation0) && (null != collation1) + : "An implicit or explicit collation should have been set"; + + // Validation will occur inside getCoercibilityDyadicOperator... + SqlCollation resultCol = SqlCollation.getCoercibilityDyadicOperator(collation0, collation1); + fieldComparator = generateCollatorExpression(resultCol); + } + else { + fieldComparator = null; + } + if (expressions.get(1).getType() == BigDecimal.class) { + return Expressions.call(decimalMethod, expressions); + } + + return fieldComparator == null + ? Expressions.call(basicMethod, expressions) + : Expressions.call(comparatorMethod, FlatLists.append(expressions, fieldComparator)); + } + + public static Boolean arrayContains(List list, Object element) { + return list == null ? null : list.contains(element); + } + + public static Boolean arrayContains(List list, BigDecimal element) { + if (list == null) { + return null; + } + for (BigDecimal t : list) { + if (t.compareTo(element) == 0) { + return true; + } + } + return false; + } + + public static Boolean arrayContains(List list, String element, Comparator comparator) { + if (list == null) { + return null; + } + for (String t : list) { + if (comparator.compare(t, element) == 0) { + return true; + } + } + return false; + } +} diff --git a/core/impl/src/main/java/com/blazebit/query/impl/calcite/function/SimpleFunctionParameter.java b/core/impl/src/main/java/com/blazebit/query/impl/calcite/function/SimpleFunctionParameter.java new file mode 100644 index 00000000..65e3c128 --- /dev/null +++ b/core/impl/src/main/java/com/blazebit/query/impl/calcite/function/SimpleFunctionParameter.java @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Blazebit + */ +package com.blazebit.query.impl.calcite.function; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.schema.FunctionParameter; + +/** + * @author Christian Beikov + * @since 1.0.0 + */ +public class SimpleFunctionParameter implements FunctionParameter { + + private final int ordinal; + private final String name; + private final Class type; + private final boolean optional; + + public SimpleFunctionParameter(int ordinal, String name, Class type, boolean optional) { + this.ordinal = ordinal; + this.name = name; + this.type = type; + this.optional = optional; + } + + @Override + public int getOrdinal() { + return ordinal; + } + + @Override + public String getName() { + return name; + } + + @Override + public RelDataType getType(RelDataTypeFactory typeFactory) { + return typeFactory.createJavaType(type); + } + + @Override + public boolean isOptional() { + return optional; + } +} diff --git a/core/impl/src/test/java/com/blazebit/query/impl/EnumArrayTest.java b/core/impl/src/test/java/com/blazebit/query/impl/EnumArrayTest.java index a41996fa..53927cb3 100644 --- a/core/impl/src/test/java/com/blazebit/query/impl/EnumArrayTest.java +++ b/core/impl/src/test/java/com/blazebit/query/impl/EnumArrayTest.java @@ -31,7 +31,9 @@ public class EnumArrayTest { public void test() { JiraCloudAdminUser model = new JiraCloudAdminUser( "u1", - List.of( PlatformRole.ADMIN ) + List.of( PlatformRole.ADMIN ), + + List.of("u1", "admin") ); QueryContextBuilder queryContextBuilder = Queries.createQueryContextBuilder(); queryContextBuilder.registerSchemaObject( JiraCloudAdminUser.class, new DataFetcher<>() { @@ -49,19 +51,24 @@ public List fetch(DataFetchContext context) { try (QueryContext queryContext = queryContextBuilder.build()) { try (QuerySession session = queryContext.createSession()) { TypedQuery query = session.createQuery( - "select r = 'ADMIN' " + + "select r = 'ADMIN', array_contains(u.platformRoles, 'ADMIN'), array_contains(u.aliases, 'ADMIN') " + "from JiraCloudAdminUser u " + "cross join unnest(u.platformRoles) r" ); List result = query.getResultList(); assertEquals( 1, result.size() ); + assertEquals( 3, result.get( 0 ).length ); + assertEquals( true, result.get( 0 )[0] ); + assertEquals( true, result.get( 0 )[1] ); + assertEquals( false, result.get( 0 )[2] ); } } } public record JiraCloudAdminUser( String id, - List platformRoles) { + List platformRoles, + List aliases) { } enum PlatformRole { diff --git a/examples/app/src/main/java/com/blazebit/query/app/Main.java b/examples/app/src/main/java/com/blazebit/query/app/Main.java index abe0f196..3c5f1e9a 100644 --- a/examples/app/src/main/java/com/blazebit/query/app/Main.java +++ b/examples/app/src/main/java/com/blazebit/query/app/Main.java @@ -58,7 +58,6 @@ import com.blazebit.query.connector.azure.graph.AzureGraphApplication; import com.blazebit.query.connector.azure.graph.AzureGraphClientAccessor; import com.blazebit.query.connector.azure.graph.AzureGraphConditionalAccessPolicy; -import com.blazebit.query.connector.azure.graph.AzureGraphConnectorConfig; import com.blazebit.query.connector.azure.graph.AzureGraphIncident; import com.blazebit.query.connector.azure.graph.AzureGraphManagedDevice; import com.blazebit.query.connector.azure.graph.AzureGraphOrganization; @@ -66,8 +65,6 @@ import com.blazebit.query.connector.azure.graph.AzureGraphUser; import com.blazebit.query.connector.azure.resourcemanager.AzureResourceBlobServiceProperties; import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagedCluster; -import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagerConnectorConfig; -import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagerPostgreSqlManagerConnectorConfig; import com.blazebit.query.connector.azure.resourcemanager.AzureResourcePostgreSqlFlexibleServer; import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagerPostgreSqlManager; import com.blazebit.query.connector.azure.resourcemanager.AzureResourcePostgreSqlFlexibleServerBackup; @@ -78,7 +75,6 @@ import com.blazebit.query.connector.azure.resourcemanager.AzureResourceVirtualMachine; import com.blazebit.query.connector.azure.resourcemanager.AzureResourceVirtualNetwork; import com.blazebit.query.connector.github.graphql.GitHubBranchProtectionRule; -import com.blazebit.query.connector.github.graphql.GitHubConnectorConfig; import com.blazebit.query.connector.github.graphql.GitHubGraphQlClient; import com.blazebit.query.connector.github.graphql.GitHubOrganization; import com.blazebit.query.connector.github.graphql.GitHubPullRequest; @@ -87,9 +83,7 @@ import com.blazebit.query.connector.github.v0314.model.OrganizationSimple; import com.blazebit.query.connector.github.v0314.model.ShortBranch; import com.blazebit.query.connector.github.v0314.model.Team; -import com.blazebit.query.connector.gitlab.GitlabConnectorConfig; import com.blazebit.query.connector.gitlab.GitlabGraphQlClient; -import com.blazebit.query.connector.gitlab.GitlabGraphQlConnectorConfig; import com.blazebit.query.connector.gitlab.GitlabGroup; import com.blazebit.query.connector.gitlab.GitlabMergeRequest; import com.blazebit.query.connector.gitlab.GitlabProject; @@ -192,7 +186,7 @@ public static void main(String[] args) throws Exception { try (EntityManagerFactory emf = Persistence.createEntityManagerFactory( "default" )) { SessionFactory sf = emf.unwrap( SessionFactory.class ); sf.inTransaction( s -> { - s.persist( new TestEntity( 1L, "Test", new TestEmbeddable( "text1", "text2" ) ) ); + s.persist( new TestEntity( 1L, "Test", new TestEmbeddable( "text1", "text2" ), Set.of(TestEnum.A, TestEnum.B) ) ); } ); CriteriaBuilderFactory cbf = Criteria.getDefault().createCriteriaBuilderFactory( emf ); @@ -202,11 +196,11 @@ public static void main(String[] args) throws Exception { EntityViewManager evm = defaultConfiguration.createEntityViewManager( cbf ); QueryContextBuilder queryContextBuilder = Queries.createQueryContextBuilder(); - queryContextBuilder.setProperty( AzureResourceManagerConnectorConfig.AZURE_RESOURCE_MANAGER.getPropertyName(), createResourceManager()); - queryContextBuilder.setPropertyProvider( AzureResourceManagerPostgreSqlManagerConnectorConfig.POSTGRESQL_MANAGER.getPropertyName(), - Main::createPostgreSqlManagers ); - queryContextBuilder.setProperty( "serverParameters", List.of("ssl_min_protocol_version", "authentication_timeout")); - queryContextBuilder.setProperty( AzureGraphConnectorConfig.GRAPH_SERVICE_CLIENT.getPropertyName(), createGraphServiceClient()); +// queryContextBuilder.setProperty( AzureResourceManagerConnectorConfig.AZURE_RESOURCE_MANAGER.getPropertyName(), createResourceManager()); +// queryContextBuilder.setPropertyProvider( AzureResourceManagerPostgreSqlManagerConnectorConfig.POSTGRESQL_MANAGER.getPropertyName(), +// Main::createPostgreSqlManagers ); +// queryContextBuilder.setProperty( "serverParameters", List.of("ssl_min_protocol_version", "authentication_timeout")); +// queryContextBuilder.setProperty( AzureGraphConnectorConfig.GRAPH_SERVICE_CLIENT.getPropertyName(), createGraphServiceClient()); // queryContextBuilder.setProperty( AwsConnectorConfig.ACCOUNT.getPropertyName(), createAwsAccount() ); // queryContextBuilder.setProperty( GoogleDirectoryConnectorConfig.GOOGLE_DIRECTORY_SERVICE.getPropertyName(), createGoogleDirectory() ); // queryContextBuilder.setProperty( GoogleDriveConnectorConfig.GOOGLE_DRIVE_SERVICE.getPropertyName(), createGoogleDrive() ); @@ -216,11 +210,11 @@ public static void main(String[] args) throws Exception { // queryContextBuilder.setProperty( "jqlQuery", "statusCategory != Done"); // queryContextBuilder.setProperty( JiraCloudAdminConnectorConfig.API_CLIENT.getPropertyName(), createJiraCloudAdminOrganizationApiClient()); queryContextBuilder.setProperty( EntityViewConnectorConfig.ENTITY_VIEW_MANAGER.getPropertyName(), evm ); - queryContextBuilder.setProperty( GitlabConnectorConfig.GITLAB_API.getPropertyName(), createGitlabApi()); - queryContextBuilder.setProperty( GitlabGraphQlConnectorConfig.GITLAB_GRAPHQL_CLIENT.getPropertyName(), createGitlabGraphQLClient()); +// queryContextBuilder.setProperty( GitlabConnectorConfig.GITLAB_API.getPropertyName(), createGitlabApi()); +// queryContextBuilder.setProperty( GitlabGraphQlConnectorConfig.GITLAB_GRAPHQL_CLIENT.getPropertyName(), createGitlabGraphQLClient()); // queryContextBuilder.setProperty(KandjiConnectorConfig.API_CLIENT.getPropertyName(), createKandjiApiClient()); // queryContextBuilder.setProperty(GithubConnectorConfig.GITHUB.getPropertyName(), createGithub()); - queryContextBuilder.setProperty( GitHubConnectorConfig.GITHUB_GRAPHQL_CLIENT.getPropertyName(), createGitHubGraphQLClient()); +// queryContextBuilder.setProperty( GitHubConnectorConfig.GITHUB_GRAPHQL_CLIENT.getPropertyName(), createGitHubGraphQLClient()); // queryContextBuilder.setProperty(com.blazebit.query.connector.github.v0314.GithubConnectorConfig.API_CLIENT.getPropertyName(), createGitHubApiClient()); // Azure Resource manager @@ -385,8 +379,8 @@ public static void main(String[] args) throws Exception { // testGitHub( session ); // testGitHubOpenAPI( session ); // testKandji( session ); -// testEntityView( session ); - testAzureGraph( session ); + testEntityView( session ); +// testAzureGraph( session ); // testAzureResourceManager( session ); } } @@ -931,6 +925,10 @@ private static void testEntityView(QuerySession session) { "select t.id, e.text1 from " + name( TestEntityView.class ) + " t, unnest(t.elements) e" ); List entityViewResult = entityViewQuery.getResultList(); print( entityViewResult, "id", "text1" ); + TypedQuery entityViewQuery2 = session.createQuery( + "select t.id, array_contains(t.enums, 'A') from " + name( TestEntityView.class ) + " t" ); + List entityViewResult2 = entityViewQuery2.getResultList(); + print( entityViewResult2, "id", "enums" ); } private static void testAzureGraph(QuerySession session) { diff --git a/examples/app/src/main/java/com/blazebit/query/app/TestEntity.java b/examples/app/src/main/java/com/blazebit/query/app/TestEntity.java index 8cb6e688..d65ff55b 100644 --- a/examples/app/src/main/java/com/blazebit/query/app/TestEntity.java +++ b/examples/app/src/main/java/com/blazebit/query/app/TestEntity.java @@ -23,15 +23,18 @@ public class TestEntity { TestEmbeddable embedded; @ElementCollection Set elements; + @ElementCollection + Set enums; public TestEntity() { } - public TestEntity(Long id, String name, TestEmbeddable embedded) { + public TestEntity(Long id, String name, TestEmbeddable embedded, Set enums) { this.id = id; this.name = name; this.embedded = embedded; this.elements = new HashSet<>( Collections.singletonList( embedded ) ); + this.enums = new HashSet<>( enums ); } @Override diff --git a/examples/app/src/main/java/com/blazebit/query/app/TestEntityView.java b/examples/app/src/main/java/com/blazebit/query/app/TestEntityView.java index 96d6cfa6..66fe373a 100644 --- a/examples/app/src/main/java/com/blazebit/query/app/TestEntityView.java +++ b/examples/app/src/main/java/com/blazebit/query/app/TestEntityView.java @@ -19,4 +19,6 @@ public interface TestEntityView { TestEmbeddableView getEmbedded(); Set getElements(); + + Set getEnums(); } diff --git a/examples/app/src/main/java/com/blazebit/query/app/TestEnum.java b/examples/app/src/main/java/com/blazebit/query/app/TestEnum.java new file mode 100644 index 00000000..6fe2b9d8 --- /dev/null +++ b/examples/app/src/main/java/com/blazebit/query/app/TestEnum.java @@ -0,0 +1,11 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright Blazebit + */ +package com.blazebit.query.app; + +public enum TestEnum { + A, + B, + C +}