Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package com.blazebit.query.impl.calcite;

import com.blazebit.query.impl.calcite.function.ArrayContainsFunction;
import org.apache.calcite.adapter.enumerable.EnumerableConvention;
import org.apache.calcite.adapter.enumerable.EnumerableRel;
import org.apache.calcite.adapter.java.JavaTypeFactory;
Expand All @@ -16,6 +17,7 @@
import org.apache.calcite.jdbc.CalcitePrepare;
import org.apache.calcite.jdbc.CalciteSchema;
import org.apache.calcite.jdbc.Driver;
import org.apache.calcite.linq4j.Queryable;
import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptPlanner;
Expand Down Expand Up @@ -73,6 +75,8 @@ public boolean shouldConvertRaggedUnionTypesToVarying() {
}
this.typeFactory = new CustomJavaTypeFactory( typeSystem );
this.rootSchema = CalciteSchema.createRootSchema( true );
// Custom functions for Blaze-Query
this.rootSchema.plus().add( "array_contains", new ArrayContainsFunction());
}

public SchemaPlus getRootSchema() {
Expand Down Expand Up @@ -160,6 +164,19 @@ public <T> T unwrap(Class<T> iface) throws SQLException {
}

private static class MyCalcitePrepareImpl extends CalcitePrepareImpl {

@Override public <T> CalciteSignature<T> prepareQueryable(Context context, Queryable<T> queryable) {
return super.prepareQueryable(wrapContext(context), queryable);
}

@Override public <T> CalciteSignature<T> prepareSql(Context context, Query<T> query, Type elementType, long maxRowCount) {
return super.prepareSql(wrapContext(context), query, elementType, maxRowCount);
}

private Context wrapContext(Context context) {
return context;
}

@Override
protected CalcitePreparingStmt getPreparingStmt(Context context, Type elementType, CalciteCatalogReader catalogReader, RelOptPlanner planner) {
final JavaTypeFactory typeFactory = context.getTypeFactory();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Blazebit
*/
package com.blazebit.query.impl.calcite.function;

import com.google.common.collect.ImmutableList;
import org.apache.calcite.adapter.enumerable.CallImplementor;
import org.apache.calcite.adapter.enumerable.RexImpTable;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.runtime.FlatLists;
import org.apache.calcite.schema.FunctionParameter;
import org.apache.calcite.schema.ImplementableFunction;
import org.apache.calcite.schema.ScalarFunction;
import org.apache.calcite.sql.SqlCollation;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.type.SqlTypeUtil;

import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.nio.charset.Charset;
import java.util.Comparator;
import java.util.List;

import static org.apache.calcite.adapter.enumerable.EnumUtils.generateCollatorExpression;
import static org.apache.calcite.linq4j.Nullness.castNonNull;
import static org.apache.calcite.util.Static.RESOURCE;

/**
* @author Christian Beikov
* @since 1.0.0
*/
public class ArrayContainsFunction implements ScalarFunction, ImplementableFunction, CallImplementor {

private final Method basicMethod;
private final Method decimalMethod;
private final Method comparatorMethod;
private final List<FunctionParameter> parameters = ImmutableList.of(
new SimpleFunctionParameter(0, "array", Object.class, false),
new SimpleFunctionParameter(1, "element", Object.class, false)
);

public ArrayContainsFunction() {
try {
this.basicMethod = ArrayContainsFunction.class.getDeclaredMethod("arrayContains", List.class, Object.class);
this.decimalMethod = ArrayContainsFunction.class.getDeclaredMethod("arrayContains", List.class, BigDecimal.class);
this.comparatorMethod = ArrayContainsFunction.class.getDeclaredMethod("arrayContains", List.class, String.class, Comparator.class);
}
catch (NoSuchMethodException e) {
throw new RuntimeException( e );
}
}

@Override
public List<FunctionParameter> getParameters() {
return parameters;
}

@Override
public RelDataType getReturnType(RelDataTypeFactory typeFactory) {
return typeFactory.createJavaType(Boolean.class);
}

@Override
public CallImplementor getImplementor() {
return this;
}

@Override
public Expression implement(RexToLixTranslator translator, RexCall call, RexImpTable.NullAs nullAs) {
final List<Expression> expressions = translator.translateList(call.getOperands());
final RelDataType operandType0 = castNonNull(call.getOperands().get(0).getType().getComponentType());
final RelDataType operandType1 = call.getOperands().get(1).getType();
final Expression fieldComparator;
if (SqlTypeUtil.inCharFamily(operandType0) && SqlTypeUtil.inCharFamily(operandType1)) {
Charset cs0 = operandType0.getCharset();
Charset cs1 = operandType1.getCharset();
assert (null != cs0) && (null != cs1)
: "An implicit or explicit charset should have been set";
if (!cs0.equals(cs1)) {
throw SqlUtil.newContextException(call.pos, RESOURCE.incompatibleCharset("array_contains", cs0.name(), cs1.name()));
}

SqlCollation collation0 = operandType0.getCollation();
SqlCollation collation1 = operandType1.getCollation();
assert (null != collation0) && (null != collation1)
: "An implicit or explicit collation should have been set";

// Validation will occur inside getCoercibilityDyadicOperator...
SqlCollation resultCol = SqlCollation.getCoercibilityDyadicOperator(collation0, collation1);
fieldComparator = generateCollatorExpression(resultCol);
}
else {
fieldComparator = null;
}
if (expressions.get(1).getType() == BigDecimal.class) {
return Expressions.call(decimalMethod, expressions);
}

return fieldComparator == null
? Expressions.call(basicMethod, expressions)
: Expressions.call(comparatorMethod, FlatLists.append(expressions, fieldComparator));
}

public static Boolean arrayContains(List<?> list, Object element) {
return list == null ? null : list.contains(element);
}

public static Boolean arrayContains(List<BigDecimal> list, BigDecimal element) {
if (list == null) {
return null;
}
for (BigDecimal t : list) {
if (t.compareTo(element) == 0) {
return true;
}
}
return false;
}

public static Boolean arrayContains(List<String> list, String element, Comparator<String> comparator) {
if (list == null) {
return null;
}
for (String t : list) {
if (comparator.compare(t, element) == 0) {
return true;
}
}
return false;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright Blazebit
*/
package com.blazebit.query.impl.calcite.function;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.schema.FunctionParameter;

/**
* @author Christian Beikov
* @since 1.0.0
*/
public class SimpleFunctionParameter implements FunctionParameter {

private final int ordinal;
private final String name;
private final Class<?> type;
private final boolean optional;

public SimpleFunctionParameter(int ordinal, String name, Class<?> type, boolean optional) {
this.ordinal = ordinal;
this.name = name;
this.type = type;
this.optional = optional;
}

@Override
public int getOrdinal() {
return ordinal;
}

@Override
public String getName() {
return name;
}

@Override
public RelDataType getType(RelDataTypeFactory typeFactory) {
return typeFactory.createJavaType(type);
}

@Override
public boolean isOptional() {
return optional;
}
}
13 changes: 10 additions & 3 deletions core/impl/src/test/java/com/blazebit/query/impl/EnumArrayTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ public class EnumArrayTest {
public void test() {
JiraCloudAdminUser model = new JiraCloudAdminUser(
"u1",
List.of( PlatformRole.ADMIN )
List.of( PlatformRole.ADMIN ),

List.of("u1", "admin")
);
QueryContextBuilder queryContextBuilder = Queries.createQueryContextBuilder();
queryContextBuilder.registerSchemaObject( JiraCloudAdminUser.class, new DataFetcher<>() {
Expand All @@ -49,19 +51,24 @@ public List<JiraCloudAdminUser> fetch(DataFetchContext context) {
try (QueryContext queryContext = queryContextBuilder.build()) {
try (QuerySession session = queryContext.createSession()) {
TypedQuery<Object[]> query = session.createQuery(
"select r = 'ADMIN' " +
"select r = 'ADMIN', array_contains(u.platformRoles, 'ADMIN'), array_contains(u.aliases, 'ADMIN') " +
"from JiraCloudAdminUser u " +
"cross join unnest(u.platformRoles) r"
);
List<Object[]> result = query.getResultList();
assertEquals( 1, result.size() );
assertEquals( 3, result.get( 0 ).length );
assertEquals( true, result.get( 0 )[0] );
assertEquals( true, result.get( 0 )[1] );
assertEquals( false, result.get( 0 )[2] );
}
}
}

public record JiraCloudAdminUser(
String id,
List<PlatformRole> platformRoles) {
List<PlatformRole> platformRoles,
List<String> aliases) {
}

enum PlatformRole {
Expand Down
32 changes: 15 additions & 17 deletions examples/app/src/main/java/com/blazebit/query/app/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,13 @@
import com.blazebit.query.connector.azure.graph.AzureGraphApplication;
import com.blazebit.query.connector.azure.graph.AzureGraphClientAccessor;
import com.blazebit.query.connector.azure.graph.AzureGraphConditionalAccessPolicy;
import com.blazebit.query.connector.azure.graph.AzureGraphConnectorConfig;
import com.blazebit.query.connector.azure.graph.AzureGraphIncident;
import com.blazebit.query.connector.azure.graph.AzureGraphManagedDevice;
import com.blazebit.query.connector.azure.graph.AzureGraphOrganization;
import com.blazebit.query.connector.azure.graph.AzureGraphServicePlanInfo;
import com.blazebit.query.connector.azure.graph.AzureGraphUser;
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceBlobServiceProperties;
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagedCluster;
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagerConnectorConfig;
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagerPostgreSqlManagerConnectorConfig;
import com.blazebit.query.connector.azure.resourcemanager.AzureResourcePostgreSqlFlexibleServer;
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceManagerPostgreSqlManager;
import com.blazebit.query.connector.azure.resourcemanager.AzureResourcePostgreSqlFlexibleServerBackup;
Expand All @@ -78,7 +75,6 @@
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceVirtualMachine;
import com.blazebit.query.connector.azure.resourcemanager.AzureResourceVirtualNetwork;
import com.blazebit.query.connector.github.graphql.GitHubBranchProtectionRule;
import com.blazebit.query.connector.github.graphql.GitHubConnectorConfig;
import com.blazebit.query.connector.github.graphql.GitHubGraphQlClient;
import com.blazebit.query.connector.github.graphql.GitHubOrganization;
import com.blazebit.query.connector.github.graphql.GitHubPullRequest;
Expand All @@ -87,9 +83,7 @@
import com.blazebit.query.connector.github.v0314.model.OrganizationSimple;
import com.blazebit.query.connector.github.v0314.model.ShortBranch;
import com.blazebit.query.connector.github.v0314.model.Team;
import com.blazebit.query.connector.gitlab.GitlabConnectorConfig;
import com.blazebit.query.connector.gitlab.GitlabGraphQlClient;
import com.blazebit.query.connector.gitlab.GitlabGraphQlConnectorConfig;
import com.blazebit.query.connector.gitlab.GitlabGroup;
import com.blazebit.query.connector.gitlab.GitlabMergeRequest;
import com.blazebit.query.connector.gitlab.GitlabProject;
Expand Down Expand Up @@ -192,7 +186,7 @@ public static void main(String[] args) throws Exception {
try (EntityManagerFactory emf = Persistence.createEntityManagerFactory( "default" )) {
SessionFactory sf = emf.unwrap( SessionFactory.class );
sf.inTransaction( s -> {
s.persist( new TestEntity( 1L, "Test", new TestEmbeddable( "text1", "text2" ) ) );
s.persist( new TestEntity( 1L, "Test", new TestEmbeddable( "text1", "text2" ), Set.of(TestEnum.A, TestEnum.B) ) );
} );

CriteriaBuilderFactory cbf = Criteria.getDefault().createCriteriaBuilderFactory( emf );
Expand All @@ -202,11 +196,11 @@ public static void main(String[] args) throws Exception {
EntityViewManager evm = defaultConfiguration.createEntityViewManager( cbf );

QueryContextBuilder queryContextBuilder = Queries.createQueryContextBuilder();
queryContextBuilder.setProperty( AzureResourceManagerConnectorConfig.AZURE_RESOURCE_MANAGER.getPropertyName(), createResourceManager());
queryContextBuilder.setPropertyProvider( AzureResourceManagerPostgreSqlManagerConnectorConfig.POSTGRESQL_MANAGER.getPropertyName(),
Main::createPostgreSqlManagers );
queryContextBuilder.setProperty( "serverParameters", List.of("ssl_min_protocol_version", "authentication_timeout"));
queryContextBuilder.setProperty( AzureGraphConnectorConfig.GRAPH_SERVICE_CLIENT.getPropertyName(), createGraphServiceClient());
// queryContextBuilder.setProperty( AzureResourceManagerConnectorConfig.AZURE_RESOURCE_MANAGER.getPropertyName(), createResourceManager());
// queryContextBuilder.setPropertyProvider( AzureResourceManagerPostgreSqlManagerConnectorConfig.POSTGRESQL_MANAGER.getPropertyName(),
// Main::createPostgreSqlManagers );
// queryContextBuilder.setProperty( "serverParameters", List.of("ssl_min_protocol_version", "authentication_timeout"));
// queryContextBuilder.setProperty( AzureGraphConnectorConfig.GRAPH_SERVICE_CLIENT.getPropertyName(), createGraphServiceClient());
// queryContextBuilder.setProperty( AwsConnectorConfig.ACCOUNT.getPropertyName(), createAwsAccount() );
// queryContextBuilder.setProperty( GoogleDirectoryConnectorConfig.GOOGLE_DIRECTORY_SERVICE.getPropertyName(), createGoogleDirectory() );
// queryContextBuilder.setProperty( GoogleDriveConnectorConfig.GOOGLE_DRIVE_SERVICE.getPropertyName(), createGoogleDrive() );
Expand All @@ -216,11 +210,11 @@ public static void main(String[] args) throws Exception {
// queryContextBuilder.setProperty( "jqlQuery", "statusCategory != Done");
// queryContextBuilder.setProperty( JiraCloudAdminConnectorConfig.API_CLIENT.getPropertyName(), createJiraCloudAdminOrganizationApiClient());
queryContextBuilder.setProperty( EntityViewConnectorConfig.ENTITY_VIEW_MANAGER.getPropertyName(), evm );
queryContextBuilder.setProperty( GitlabConnectorConfig.GITLAB_API.getPropertyName(), createGitlabApi());
queryContextBuilder.setProperty( GitlabGraphQlConnectorConfig.GITLAB_GRAPHQL_CLIENT.getPropertyName(), createGitlabGraphQLClient());
// queryContextBuilder.setProperty( GitlabConnectorConfig.GITLAB_API.getPropertyName(), createGitlabApi());
// queryContextBuilder.setProperty( GitlabGraphQlConnectorConfig.GITLAB_GRAPHQL_CLIENT.getPropertyName(), createGitlabGraphQLClient());
// queryContextBuilder.setProperty(KandjiConnectorConfig.API_CLIENT.getPropertyName(), createKandjiApiClient());
// queryContextBuilder.setProperty(GithubConnectorConfig.GITHUB.getPropertyName(), createGithub());
queryContextBuilder.setProperty( GitHubConnectorConfig.GITHUB_GRAPHQL_CLIENT.getPropertyName(), createGitHubGraphQLClient());
// queryContextBuilder.setProperty( GitHubConnectorConfig.GITHUB_GRAPHQL_CLIENT.getPropertyName(), createGitHubGraphQLClient());
// queryContextBuilder.setProperty(com.blazebit.query.connector.github.v0314.GithubConnectorConfig.API_CLIENT.getPropertyName(), createGitHubApiClient());

// Azure Resource manager
Expand Down Expand Up @@ -385,8 +379,8 @@ public static void main(String[] args) throws Exception {
// testGitHub( session );
// testGitHubOpenAPI( session );
// testKandji( session );
// testEntityView( session );
testAzureGraph( session );
testEntityView( session );
// testAzureGraph( session );
// testAzureResourceManager( session );
}
}
Expand Down Expand Up @@ -931,6 +925,10 @@ private static void testEntityView(QuerySession session) {
"select t.id, e.text1 from " + name( TestEntityView.class ) + " t, unnest(t.elements) e" );
List<Object[]> entityViewResult = entityViewQuery.getResultList();
print( entityViewResult, "id", "text1" );
TypedQuery<Object[]> entityViewQuery2 = session.createQuery(
"select t.id, array_contains(t.enums, 'A') from " + name( TestEntityView.class ) + " t" );
List<Object[]> entityViewResult2 = entityViewQuery2.getResultList();
print( entityViewResult2, "id", "enums" );
}

private static void testAzureGraph(QuerySession session) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,18 @@ public class TestEntity {
TestEmbeddable embedded;
@ElementCollection
Set<TestEmbeddable> elements;
@ElementCollection
Set<TestEnum> enums;

public TestEntity() {
}

public TestEntity(Long id, String name, TestEmbeddable embedded) {
public TestEntity(Long id, String name, TestEmbeddable embedded, Set<TestEnum> enums) {
this.id = id;
this.name = name;
this.embedded = embedded;
this.elements = new HashSet<>( Collections.singletonList( embedded ) );
this.enums = new HashSet<>( enums );
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@ public interface TestEntityView {
TestEmbeddableView getEmbedded();

Set<TestEmbeddableView> getElements();

Set<TestEnum> getEnums();
}
Loading