Skip to content

Commit d09745a

Browse files
committed
[FLINK-22650][python][table-planner-blink] Support StreamExecPythonCorrelate json serialization/deserialization
This closes #15922.
1 parent 5fab2da commit d09745a

14 files changed

Lines changed: 1048 additions & 45 deletions

File tree

flink-python/pyflink/table/tests/test_udtf.py

Lines changed: 49 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,55 @@ class PyFlinkStreamUserDefinedTableFunctionTests(UserDefinedTableFunctionTests,
7878

7979
class PyFlinkBlinkStreamUserDefinedFunctionTests(UserDefinedTableFunctionTests,
8080
PyFlinkBlinkStreamTableTestCase):
81-
pass
81+
def test_execute_from_json_plan(self):
82+
# create source file path
83+
tmp_dir = self.tempdir
84+
data = ['1,1', '3,2', '2,1']
85+
source_path = tmp_dir + '/test_execute_from_json_plan_input.csv'
86+
sink_path = tmp_dir + '/test_execute_from_json_plan_out'
87+
with open(source_path, 'w') as fd:
88+
for ele in data:
89+
fd.write(ele + '\n')
90+
91+
source_table = """
92+
CREATE TABLE source_table (
93+
a BIGINT,
94+
b BIGINT
95+
) WITH (
96+
'connector' = 'filesystem',
97+
'path' = '%s',
98+
'format' = 'csv'
99+
)
100+
""" % source_path
101+
self.t_env.execute_sql(source_table)
102+
103+
self.t_env.execute_sql("""
104+
CREATE TABLE sink_table (
105+
a BIGINT,
106+
b BIGINT,
107+
c BIGINT
108+
) WITH (
109+
'connector' = 'filesystem',
110+
'path' = '%s',
111+
'format' = 'csv'
112+
)
113+
""" % sink_path)
114+
115+
self.t_env.create_temporary_system_function(
116+
"multi_emit", udtf(MultiEmit(), result_types=[DataTypes.BIGINT(), DataTypes.BIGINT()]))
117+
118+
json_plan = self.t_env._j_tenv.getJsonPlan("INSERT INTO sink_table "
119+
"SELECT a, x, y FROM source_table "
120+
"LEFT JOIN LATERAL TABLE(multi_emit(a, b))"
121+
" as T(x, y)"
122+
" ON TRUE")
123+
from py4j.java_gateway import get_method
124+
get_method(self.t_env._j_tenv.executeJsonPlan(json_plan), "await")()
125+
126+
import glob
127+
lines = [line.strip() for file in glob.glob(sink_path + '/*') for line in open(file, 'r')]
128+
lines.sort()
129+
self.assertEqual(lines, ['1,1,0', '2,2,0', '3,3,0', '3,3,1'])
82130

83131

84132
class PyFlinkBlinkBatchUserDefinedFunctionTests(UserDefinedTableFunctionTests,

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/batch/BatchExecPythonCorrelate.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
import org.apache.flink.table.types.logical.RowType;
2626

2727
import org.apache.calcite.rex.RexCall;
28-
import org.apache.calcite.rex.RexNode;
28+
29+
import java.util.Collections;
2930

3031
/** Batch exec node which matches along with join a Python user defined table function. */
3132
public class BatchExecPythonCorrelate extends CommonExecPythonCorrelate
@@ -34,10 +35,15 @@ public class BatchExecPythonCorrelate extends CommonExecPythonCorrelate
3435
public BatchExecPythonCorrelate(
3536
FlinkJoinType joinType,
3637
RexCall invocation,
37-
RexNode condition,
3838
InputProperty inputProperty,
3939
RowType outputType,
4040
String description) {
41-
super(joinType, invocation, condition, inputProperty, outputType, description);
41+
super(
42+
joinType,
43+
invocation,
44+
getNewNodeId(),
45+
Collections.singletonList(inputProperty),
46+
outputType,
47+
description);
4248
}
4349
}

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCorrelate.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,37 +38,47 @@
3838
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
3939
import org.apache.flink.table.types.logical.RowType;
4040

41+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnoreProperties;
42+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
43+
4144
import org.apache.calcite.rex.RexCall;
4245
import org.apache.calcite.rex.RexInputRef;
4346
import org.apache.calcite.rex.RexNode;
4447

4548
import java.lang.reflect.Constructor;
46-
import java.util.Collections;
4749
import java.util.LinkedHashMap;
50+
import java.util.List;
51+
52+
import static org.apache.flink.util.Preconditions.checkArgument;
4853

4954
/** Base {@link ExecNode} which matches along with join a Python user defined table function. */
55+
@JsonIgnoreProperties(ignoreUnknown = true)
5056
public abstract class CommonExecPythonCorrelate extends ExecNodeBase<RowData>
5157
implements SingleTransformationTranslator<RowData> {
58+
59+
public static final String FIELD_NAME_JOIN_TYPE = "joinType";
60+
public static final String FIELD_NAME_FUNCTION_CALL = "functionCall";
61+
5262
private static final String PYTHON_TABLE_FUNCTION_OPERATOR_NAME =
5363
"org.apache.flink.table.runtime.operators.python.table.RowDataPythonTableFunctionOperator";
5464

65+
@JsonProperty(FIELD_NAME_JOIN_TYPE)
5566
private final FlinkJoinType joinType;
67+
68+
@JsonProperty(FIELD_NAME_FUNCTION_CALL)
5669
private final RexCall invocation;
5770

5871
public CommonExecPythonCorrelate(
5972
FlinkJoinType joinType,
6073
RexCall invocation,
61-
RexNode condition,
62-
InputProperty inputProperty,
74+
int id,
75+
List<InputProperty> inputProperties,
6376
RowType outputType,
6477
String description) {
65-
super(Collections.singletonList(inputProperty), outputType, description);
78+
super(id, inputProperties, outputType, description);
79+
checkArgument(inputProperties.size() == 1);
6680
this.joinType = joinType;
6781
this.invocation = invocation;
68-
if (joinType == FlinkJoinType.LEFT && condition != null) {
69-
throw new TableException(
70-
"Currently Python correlate does not support conditions in left join.");
71-
}
7282
}
7383

7484
@SuppressWarnings("unchecked")

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonCorrelate.java

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,43 @@
2424
import org.apache.flink.table.runtime.operators.join.FlinkJoinType;
2525
import org.apache.flink.table.types.logical.RowType;
2626

27+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
28+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonIgnoreProperties;
29+
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
30+
2731
import org.apache.calcite.rex.RexCall;
2832
import org.apache.calcite.rex.RexNode;
2933

34+
import java.util.Collections;
35+
import java.util.List;
36+
3037
/** Stream exec node which matches along with join a Python user defined table function. */
38+
@JsonIgnoreProperties(ignoreUnknown = true)
3139
public class StreamExecPythonCorrelate extends CommonExecPythonCorrelate
3240
implements StreamExecNode<RowData> {
3341
public StreamExecPythonCorrelate(
3442
FlinkJoinType joinType,
3543
RexCall invocation,
36-
RexNode condition,
3744
InputProperty inputProperty,
3845
RowType outputType,
3946
String description) {
40-
super(joinType, invocation, condition, inputProperty, outputType, description);
47+
this(
48+
joinType,
49+
invocation,
50+
getNewNodeId(),
51+
Collections.singletonList(inputProperty),
52+
outputType,
53+
description);
54+
}
55+
56+
@JsonCreator
57+
public StreamExecPythonCorrelate(
58+
@JsonProperty(FIELD_NAME_JOIN_TYPE) FlinkJoinType joinType,
59+
@JsonProperty(FIELD_NAME_FUNCTION_CALL) RexNode invocation,
60+
@JsonProperty(FIELD_NAME_ID) int id,
61+
@JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List<InputProperty> inputProperties,
62+
@JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType,
63+
@JsonProperty(FIELD_NAME_DESCRIPTION) String description) {
64+
super(joinType, (RexCall) invocation, id, inputProperties, outputType, description);
4165
}
4266
}

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,17 @@
2929
import org.apache.calcite.plan.hep.HepRelVertex;
3030
import org.apache.calcite.rel.RelNode;
3131
import org.apache.calcite.rel.type.RelDataType;
32+
import org.apache.calcite.rel.type.RelDataTypeField;
3233
import org.apache.calcite.rex.RexBuilder;
3334
import org.apache.calcite.rex.RexCall;
35+
import org.apache.calcite.rex.RexCorrelVariable;
3436
import org.apache.calcite.rex.RexFieldAccess;
3537
import org.apache.calcite.rex.RexInputRef;
3638
import org.apache.calcite.rex.RexNode;
3739
import org.apache.calcite.rex.RexProgram;
3840
import org.apache.calcite.rex.RexProgramBuilder;
3941
import org.apache.calcite.rex.RexUtil;
42+
import org.apache.calcite.rex.RexVisitorImpl;
4043
import org.apache.calcite.sql.validate.SqlValidatorUtil;
4144

4245
import java.util.LinkedList;
@@ -112,10 +115,36 @@ private List<String> createNewFieldNames(
112115
for (int i = 0; i < primitiveFieldCount; i++) {
113116
calcProjects.add(RexInputRef.of(i, rowType));
114117
}
118+
// change RexCorrelVariable to RexInputRef.
119+
RexVisitorImpl<RexNode> visitor =
120+
new RexVisitorImpl<RexNode>(true) {
121+
@Override
122+
public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
123+
RexNode expr = fieldAccess.getReferenceExpr();
124+
if (expr instanceof RexCorrelVariable) {
125+
RelDataTypeField field = fieldAccess.getField();
126+
return new RexInputRef(field.getIndex(), field.getType());
127+
} else {
128+
return rexBuilder.makeFieldAccess(
129+
expr.accept(this), fieldAccess.getField().getIndex());
130+
}
131+
}
132+
};
115133
// add the fields of the extracted rex calls.
116134
Iterator<RexNode> iterator = extractedRexNodes.iterator();
117135
while (iterator.hasNext()) {
118-
calcProjects.add(iterator.next());
136+
RexNode rexNode = iterator.next();
137+
if (rexNode instanceof RexCall) {
138+
RexCall rexCall = (RexCall) rexNode;
139+
List<RexNode> newProjects =
140+
rexCall.getOperands().stream()
141+
.map(x -> x.accept(visitor))
142+
.collect(Collectors.toList());
143+
RexCall newRexCall = rexCall.clone(rexCall.getType(), newProjects);
144+
calcProjects.add(newRexCall);
145+
} else {
146+
calcProjects.add(rexNode);
147+
}
119148
}
120149

121150
List<String> nameList = new LinkedList<>();
@@ -252,18 +281,31 @@ public void onMatch(RelOptRuleCall call) {
252281
mergedCalc.copy(mergedCalc.getTraitSet(), newScan, mergedCalc.getProgram());
253282
}
254283

255-
FlinkLogicalCalc leftCalc =
256-
createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);
284+
FlinkLogicalCorrelate newCorrelate;
285+
if (extractedRexNodes.size() > 0) {
286+
FlinkLogicalCalc leftCalc =
287+
createNewLeftCalc(left, rexBuilder, extractedRexNodes, correlate);
257288

258-
FlinkLogicalCorrelate newCorrelate =
259-
new FlinkLogicalCorrelate(
260-
correlate.getCluster(),
261-
correlate.getTraitSet(),
262-
leftCalc,
263-
rightNewInput,
264-
correlate.getCorrelationId(),
265-
correlate.getRequiredColumns(),
266-
correlate.getJoinType());
289+
newCorrelate =
290+
new FlinkLogicalCorrelate(
291+
correlate.getCluster(),
292+
correlate.getTraitSet(),
293+
leftCalc,
294+
rightNewInput,
295+
correlate.getCorrelationId(),
296+
correlate.getRequiredColumns(),
297+
correlate.getJoinType());
298+
} else {
299+
newCorrelate =
300+
new FlinkLogicalCorrelate(
301+
correlate.getCluster(),
302+
correlate.getTraitSet(),
303+
left,
304+
rightNewInput,
305+
correlate.getCorrelationId(),
306+
correlate.getRequiredColumns(),
307+
correlate.getJoinType());
308+
}
267309

268310
FlinkLogicalCalc newTopCalc =
269311
createTopCalc(

flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/batch/BatchPhysicalPythonCorrelate.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
*/
1818
package org.apache.flink.table.planner.plan.nodes.physical.batch
1919

20+
import org.apache.flink.table.api.TableException
2021
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
2122
import org.apache.flink.table.planner.plan.nodes.exec.batch.BatchExecPythonCorrelate
2223
import org.apache.flink.table.planner.plan.nodes.exec.{InputProperty, ExecNode}
@@ -64,10 +65,17 @@ class BatchPhysicalPythonCorrelate(
6465
}
6566

6667
override def translateToExecNode(): ExecNode[_] = {
68+
if (condition.orNull != null) {
69+
if (joinType == JoinRelType.LEFT) {
70+
throw new TableException("Currently Python correlate does not support conditions" +
71+
" in left join.")
72+
}
73+
throw new TableException("The condition of BatchPhysicalPythonCorrelate should be null.")
74+
}
75+
6776
new BatchExecPythonCorrelate(
6877
JoinTypeUtil.getFlinkJoinType(joinType),
6978
scan.getCall.asInstanceOf[RexCall],
70-
condition.orNull,
7179
InputProperty.DEFAULT,
7280
FlinkTypeFactory.toLogicalRowType(getRowType),
7381
getRelDetailedDescription

flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/stream/StreamPhysicalPythonCorrelate.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,18 @@
1717
*/
1818
package org.apache.flink.table.planner.plan.nodes.physical.stream
1919

20+
import org.apache.flink.table.api.TableException
2021
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
2122
import org.apache.flink.table.planner.plan.nodes.exec.stream.StreamExecPythonCorrelate
2223
import org.apache.flink.table.planner.plan.nodes.exec.{InputProperty, ExecNode}
2324
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan
25+
import org.apache.flink.table.planner.plan.utils.JoinTypeUtil
26+
2427
import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
2528
import org.apache.calcite.rel.RelNode
2629
import org.apache.calcite.rel.`type`.RelDataType
2730
import org.apache.calcite.rel.core.JoinRelType
2831
import org.apache.calcite.rex.{RexCall, RexNode}
29-
import org.apache.flink.table.planner.plan.utils.JoinTypeUtil
3032

3133
/**
3234
* Flink RelNode which matches along with join a python user defined table function.
@@ -63,10 +65,17 @@ class StreamPhysicalPythonCorrelate(
6365
}
6466

6567
override def translateToExecNode(): ExecNode[_] = {
68+
if (condition.orNull != null) {
69+
if (joinType == JoinRelType.LEFT) {
70+
throw new TableException("Currently Python correlate does not support conditions" +
71+
" in left join.")
72+
}
73+
throw new TableException("The condition of StreamPhysicalPythonCorrelate should be null.")
74+
}
75+
6676
new StreamExecPythonCorrelate(
6777
JoinTypeUtil.getFlinkJoinType(joinType),
6878
scan.getCall.asInstanceOf[RexCall],
69-
condition.orNull,
7079
InputProperty.DEFAULT,
7180
FlinkTypeFactory.toLogicalRowType(getRowType),
7281
getRelDetailedDescription

flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.util.function.Function
2222

2323
import org.apache.calcite.plan.RelOptRule.{any, operand}
2424
import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
25-
import org.apache.calcite.rex.{RexBuilder, RexCall, RexFieldAccess, RexInputRef, RexLocalRef, RexNode, RexProgram}
25+
import org.apache.calcite.rex.{RexBuilder, RexCall, RexCorrelVariable, RexFieldAccess, RexInputRef, RexLocalRef, RexNode, RexProgram}
2626
import org.apache.calcite.sql.validate.SqlValidatorUtil
2727
import org.apache.flink.table.functions.ScalarFunction
2828
import org.apache.flink.table.functions.python.PythonFunctionKind
@@ -393,7 +393,13 @@ private class ScalarFunctionSplitter(
393393
expr match {
394394
case localRef: RexLocalRef if containsPythonCall(program.expandLocalRef(localRef))
395395
=> getExtractedRexFieldAccess(fieldAccess, localRef.getIndex)
396-
case _ => getExtractedRexNode(fieldAccess)
396+
case _: RexCorrelVariable =>
397+
val field = fieldAccess.getField
398+
new RexInputRef(field.getIndex, field.getType)
399+
case _ =>
400+
val newFieldAccess = rexBuilder.makeFieldAccess(
401+
expr.accept(this), fieldAccess.getField.getIndex)
402+
getExtractedRexNode(newFieldAccess)
397403
}
398404
} else {
399405
fieldAccess

flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/JsonSerdeCoverageTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ public class JsonSerdeCoverageTest {
4848
"StreamExecGroupTableAggregate",
4949
"StreamExecPythonGroupTableAggregate",
5050
"StreamExecPythonOverAggregate",
51-
"StreamExecPythonCorrelate",
5251
"StreamExecSort",
5352
"StreamExecMultipleInput",
5453
"StreamExecValues");

0 commit comments

Comments
 (0)