Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion flink-python/pyflink/table/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,55 @@ class PyFlinkStreamUserDefinedTableFunctionTests(UserDefinedTableFunctionTests,

class PyFlinkBlinkStreamUserDefinedFunctionTests(UserDefinedTableFunctionTests,
PyFlinkBlinkStreamTableTestCase):
pass
def test_execute_from_json_plan(self):
# create source file path
tmp_dir = self.tempdir
data = ['1,1', '3,2', '2,1']
source_path = tmp_dir + '/test_execute_from_json_plan_input.csv'
sink_path = tmp_dir + '/test_execute_from_json_plan_out'
with open(source_path, 'w') as fd:
for ele in data:
fd.write(ele + '\n')

source_table = """
CREATE TABLE source_table (
a BIGINT,
b BIGINT
) WITH (
'connector' = 'filesystem',
'path' = '%s',
'format' = 'csv'
)
""" % source_path
self.t_env.execute_sql(source_table)

self.t_env.execute_sql("""
CREATE TABLE sink_table (
a BIGINT,
b BIGINT,
c BIGINT
) WITH (
'connector' = 'filesystem',
'path' = '%s',
'format' = 'csv'
)
""" % sink_path)

self.t_env.create_temporary_system_function(
"multi_emit", udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()]))

json_plan = self.t_env._j_tenv.getJsonPlan("INSERT INTO sink_table "
"SELECT a, x, y FROM source_table "
"LEFT JOIN LATERAL TABLE(multi_emit(a, b))"
" as T(x, y)"
" ON TRUE")
from py4j.java_gateway import get_method
get_method(self.t_env._j_tenv.executeJsonPlan(json_plan), "await")()

import glob
lines = [line.strip() for file in glob.glob(sink_path + '/*') for line in open(file, 'r')]
lines.sort()
self.assertEqual(lines, ['1,1,0', '2,2,0', '3,3,0', '3,3,1'])


class PyFlinkBlinkBatchUserDefinedFunctionTests(UserDefinedTableFunctionTests,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
import org.apache.flink.table.types.logical.RowType;

import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;

import java.util.Collections;

/** Batch exec node which matches along with join a Python user defined table function. */
public class BatchExecPythonCorrelate extends CommonExecPythonCorrelate
Expand All @@ -34,10 +35,15 @@ public class BatchExecPythonCorrelate extends CommonExecPythonCorrelate
public BatchExecPythonCorrelate(
FlinkJoinType joinType,
RexCall invocation,
RexNode condition,
InputProperty inputProperty,
RowType outputType,
String description) {
super(joinType, invocation, condition, inputProperty, outputType, description);
super(
joinType,
invocation,
getNewNodeId(),
Collections.singletonList(inputProperty),
outputType,
description);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,37 +38,47 @@
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.RowType;

import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;

import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;

import java.lang.reflect.Constructor;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;

import static org.apache.flink.util.Preconditions.checkArgument;

/** Base {@link ExecNode} which matches along with join a Python user defined table function. */
@JsonIgnoreProperties(ignoreUnknown = true)
public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
implements SingleTransformationTranslator<RowData> {

public static final String FIELD_NAME_JOIN_TYPE = "joinType";
public static final String FIELD_NAME_FUNCTION_CALL = "functionCall";

private static final String PYTHON_TABLE_FUNCTION_OPERATOR_NAME =
"org.apache.flink.table.runtime.operators.python.table.RowDataPythonTableFunctionOperator";

@JsonProperty(FIELD_NAME_JOIN_TYPE)
private final FlinkJoinType joinType;

@JsonProperty(FIELD_NAME_FUNCTION_CALL)
private final RexCall invocation;

public CommonExecPythonCorrelate(
FlinkJoinType joinType,
RexCall invocation,
RexNode condition,
InputProperty inputProperty,
int id,
List<InputProperty> inputProperties,
RowType outputType,
String description) {
super(Collections.singletonList(inputProperty), outputType, description);
super(id, inputProperties, outputType, description);
checkArgument(inputProperties.size() == 1);
this.joinType = joinType;
this.invocation = invocation;
if (joinType == FlinkJoinType.LEFT && condition != null) {
throw new TableException(
"Currently Python correlate does not support conditions in left join.");
}
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,43 @@
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
import org.apache.flink.table.types.logical.RowType;

import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;

import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;

import java.util.Collections;
import java.util.List;

/** Stream exec node which matches along with join a Python user defined table function. */
@JsonIgnoreProperties(ignoreUnknown = true)
public class StreamExecPythonCorrelate extends CommonExecPythonCorrelate
implements StreamExecNode<RowData> {
public StreamExecPythonCorrelate(
FlinkJoinType joinType,
RexCall invocation,
RexNode condition,
InputProperty inputProperty,
RowType outputType,
String description) {
super(joinType, invocation, condition, inputProperty, outputType, description);
this(
joinType,
invocation,
getNewNodeId(),
Collections.singletonList(inputProperty),
outputType,
description);
}

@JsonCreator
public StreamExecPythonCorrelate(
@JsonProperty(FIELD_NAME_JOIN_TYPE) FlinkJoinType joinType,
@JsonProperty(FIELD_NAME_FUNCTION_CALL) RexNode invocation,
@JsonProperty(FIELD_NAME_ID) int id,
@JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List<InputProperty> inputProperties,
@JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType,
@JsonProperty(FIELD_NAME_DESCRIPTION) String description) {
super(joinType, (RexCall) invocation, id, inputProperties, outputType, description);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.validate.SqlValidatorUtil;

import java.util.LinkedList;
Expand Down Expand Up @@ -112,10 +115,36 @@ private List<String> createNewFieldNames(
for (int i = 0; i < primitiveFieldCount; i++) {
calcProjects.add(RexInputRef.of(i, rowType));
}
// change RexCorrelVariable to RexInputRef.
RexVisitorImpl<RexNode> visitor =
new RexVisitorImpl<RexNode>(true) {
@Override
public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
RexNode expr = fieldAccess.getReferenceExpr();
if (expr instanceof RexCorrelVariable) {
RelDataTypeField field = fieldAccess.getField();
return new RexInputRef(field.getIndex(), field.getType());
} else {
return rexBuilder.makeFieldAccess(
expr.accept(this), fieldAccess.getField().getIndex());
}
}
};
// add the fields of the extracted rex calls.
Iterator<RexNode> iterator = extractedRexNodes.iterator();
while (iterator.hasNext()) {
calcProjects.add(iterator.next());
RexNode rexNode = iterator.next();
if (rexNode instanceof RexCall) {
RexCall rexCall = (RexCall) rexNode;
List<RexNode> newProjects =
rexCall.getOperands().stream()
.map(x -> x.accept(visitor))
.collect(Collectors.toList());
RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects);
calcProjects.add(newRexCall);
} else {
calcProjects.add(rexNode);
}
}

List<String> nameList = new LinkedList<>();
Expand Down Expand Up @@ -252,18 +281,31 @@ public void onMatch(RelOptRuleCall call) {
mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram());
}

FlinkLogicalCalc leftCalc =
createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);
FlinkLogicalCorrelate newCorrelate;
if (extractedRexNodes.size() > 0) {
FlinkLogicalCalc leftCalc =
createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);

FlinkLogicalCorrelate newCorrelate =
new FlinkLogicalCorrelate(
correlate.getCluster(),
correlate.getTraitSet(),
leftCalc,
rightNewInput,
correlate.getCorrelationId(),
correlate.getRequiredColumns(),
correlate.getJoinType());
newCorrelate =
new FlinkLogicalCorrelate(
correlate.getCluster(),
correlate.getTraitSet(),
leftCalc,
rightNewInput,
correlate.getCorrelationId(),
correlate.getRequiredColumns(),
correlate.getJoinType());
} else {
newCorrelate =
new FlinkLogicalCorrelate(
correlate.getCluster(),
correlate.getTraitSet(),
left,
rightNewInput,
correlate.getCorrelationId(),
correlate.getRequiredColumns(),
correlate.getJoinType());
}

FlinkLogicalCalc newTopCalc =
createTopCalc(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.flink.table.planner.plan.nodes.physical.batch

import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecPythonCorrelate
import org.apache.flink.table.planner.plan.nodes.exec.{InputProperty, ExecNode}
Expand Down Expand Up @@ -64,10 +65,17 @@ class BatchPhysicalPythonCorrelate(
}

override def translateToExecNode(): ExecNode[_] = {
if (condition.orNull != null) {
if (joinType == JoinRelType.LEFT) {
throw new TableException("Currently Python correlate does not support conditions" +
" in left join.")
}
throw new TableException("The condition of BatchPhysicalPythonCorrelate should be null.")
}

new BatchExecPythonCorrelate(
JoinTypeUtil.getFlinkJoinType(joinType),
scan.getCall.asInstanceOf[RexCall],
condition.orNull,
InputProperty.DEFAULT,
FlinkTypeFactory.toLogicalRowType(getRowType),
getRelDetailedDescription
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@
*/
package org.apache.flink.table.planner.plan.nodes.physical.stream

import org.apache.flink.table.api.TableException
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecPythonCorrelate
import org.apache.flink.table.planner.plan.nodes.exec.{InputProperty, ExecNode}
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan
import org.apache.flink.table.planner.plan.utils.JoinTypeUtil

import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
import org.apache.calcite.rel.RelNode
import org.apache.calcite.rel.`type`.RelDataType
import org.apache.calcite.rel.core.JoinRelType
import org.apache.calcite.rex.{RexCall, RexNode}
import org.apache.flink.table.planner.plan.utils.JoinTypeUtil

/**
* Flink RelNode which matches along with join a python user defined table function.
Expand Down Expand Up @@ -63,10 +65,17 @@ class StreamPhysicalPythonCorrelate(
}

override def translateToExecNode(): ExecNode[_] = {
if (condition.orNull != null) {
if (joinType == JoinRelType.LEFT) {
throw new TableException("Currently Python correlate does not support conditions" +
" in left join.")
}
throw new TableException("The condition of StreamPhysicalPythonCorrelate should be null.")
}

new StreamExecPythonCorrelate(
JoinTypeUtil.getFlinkJoinType(joinType),
scan.getCall.asInstanceOf[RexCall],
condition.orNull,
InputProperty.DEFAULT,
FlinkTypeFactory.toLogicalRowType(getRowType),
getRelDetailedDescription
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.function.Function

import org.apache.calcite.plan.RelOptRule.{any, operand}
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
import org.apache.calcite.rex.{RexBuilder, RexCall, RexFieldAccess, RexInputRef, RexLocalRef, RexNode, RexProgram}
import org.apache.calcite.rex.{RexBuilder, RexCall, RexCorrelVariable, RexFieldAccess, RexInputRef, RexLocalRef, RexNode, RexProgram}
import org.apache.calcite.sql.validate.SqlValidatorUtil
import org.apache.flink.table.functions.ScalarFunction
import org.apache.flink.table.functions.python.PythonFunctionKind
Expand Down Expand Up @@ -393,7 +393,13 @@ private class ScalarFunctionSplitter(
expr match {
case localRef: RexLocalRef if containsPythonCall(program.expandLocalRef(localRef))
=> getExtractedRexFieldAccess(fieldAccess, localRef.getIndex)
case _ => getExtractedRexNode(fieldAccess)
case _: RexCorrelVariable =>
val field = fieldAccess.getField
new RexInputRef(field.getIndex, field.getType)
case _ =>
val newFieldAccess = rexBuilder.makeFieldAccess(
expr.accept(this), fieldAccess.getField.getIndex)
getExtractedRexNode(newFieldAccess)
}
} else {
fieldAccess
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ public class JsonSerdeCoverageTest {
"StreamExecGroupTableAggregate",
"StreamExecPythonGroupTableAggregate",
"StreamExecPythonOverAggregate",
"StreamExecPythonCorrelate",
"StreamExecSort",
"StreamExecMultipleInput",
"StreamExecValues");
Expand Down
Loading