Commit bef596d
[FLINK-22653][python][table-planner-blink] Support StreamExecPythonOverAggregate json serialization/deserialization
1 parent 5eebab4 commit bef596d

9 files changed: 2271 additions & 3 deletions


flink-python/pyflink/table/tests/test_pandas_udaf.py

Lines changed: 70 additions & 0 deletions
@@ -760,6 +760,76 @@ def test_proc_time_over_rows_window_aggregate_function(self):
                           "+I[3, 2.0, 4]"])
         os.remove(source_path)
 
+    def test_execute_over_aggregate_from_json_plan(self):
+        # create source file path
+        tmp_dir = self.tempdir
+        data = [
+            '1,1,2013-01-01 03:10:00',
+            '3,2,2013-01-01 03:10:00',
+            '2,1,2013-01-01 03:10:00',
+            '1,5,2013-01-01 03:10:00',
+            '1,8,2013-01-01 04:20:00',
+            '2,3,2013-01-01 03:30:00'
+        ]
+        source_path = tmp_dir + '/test_execute_over_aggregate_from_json_plan.csv'
+        sink_path = tmp_dir + '/test_execute_over_aggregate_from_json_plan'
+        with open(source_path, 'w') as fd:
+            for ele in data:
+                fd.write(ele + '\n')
+
+        source_table = """
+            CREATE TABLE source_table (
+                a TINYINT,
+                b SMALLINT,
+                rowtime TIMESTAMP(3),
+                WATERMARK FOR rowtime AS rowtime - INTERVAL '60' MINUTE
+            ) WITH (
+                'connector' = 'filesystem',
+                'path' = '%s',
+                'format' = 'csv'
+            )
+        """ % source_path
+        self.t_env.execute_sql(source_table)
+
+        self.t_env.execute_sql("""
+            CREATE TABLE sink_table (
+                a TINYINT,
+                b FLOAT,
+                c SMALLINT
+            ) WITH (
+                'connector' = 'filesystem',
+                'path' = '%s',
+                'format' = 'csv'
+            )
+        """ % sink_path)
+
+        max_add_min_udaf = udaf(lambda a: a.max() + a.min(),
+                                result_type=DataTypes.SMALLINT(),
+                                func_type='pandas')
+        self.t_env.get_config().get_configuration().set_string(
+            "pipeline.time-characteristic", "EventTime")
+        self.t_env.create_temporary_system_function("mean_udaf", mean_udaf)
+        self.t_env.create_temporary_system_function("max_add_min_udaf", max_add_min_udaf)
+
+        json_plan = self.t_env._j_tenv.getJsonPlan("""
+            insert into sink_table
+            select a,
+                   mean_udaf(b)
+                       over (PARTITION BY a ORDER BY rowtime
+                             ROWS BETWEEN 1 PRECEDING AND CURRENT ROW),
+                   max_add_min_udaf(b)
+                       over (PARTITION BY a ORDER BY rowtime
+                             ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)
+            from source_table
+        """)
+        from py4j.java_gateway import get_method
+        get_method(self.t_env._j_tenv.executeJsonPlan(json_plan), "await")()
+
+        import glob
+        lines = [line.strip() for file in glob.glob(sink_path + '/*') for line in open(file, 'r')]
+        lines.sort()
+        self.assertEqual(lines, ['1,1.0,2', '1,3.0,6', '1,6.5,13', '2,1.0,2', '2,2.0,4', '3,2.0,4'])
+
 
 @udaf(result_type=DataTypes.FLOAT(), func_type="pandas")
 def mean_udaf(v):
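The expected sink rows double as a worked example of the two pandas UDAFs under ROWS BETWEEN 1 PRECEDING AND CURRENT ROW, i.e. a sliding window of at most two rows per partition. A minimal sketch in plain pandas — not part of the commit — reproduces the three rows for partition a = 1, whose b values arrive as 1, 5, 8:

import pandas as pd

# b values for partition a = 1, ordered by rowtime as in the source file
b = pd.Series([1, 5, 8])

# mean_udaf over the two-row sliding window
mean_over_window = b.rolling(window=2, min_periods=1).mean()

# max_add_min_udaf (lambda a: a.max() + a.min()) over the same window
max_add_min = (b.rolling(window=2, min_periods=1).max()
               + b.rolling(window=2, min_periods=1).min())

print(mean_over_window.tolist())  # [1.0, 3.0, 6.5]
print(max_add_min.tolist())       # [2.0, 6.0, 13.0] -> sink rows '1,1.0,2', '1,3.0,6', '1,6.5,13'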

flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecPythonOverAggregate.java

Lines changed: 28 additions & 2 deletions
@@ -46,6 +46,9 @@
 import org.apache.flink.table.types.logical.TimestampKind;
 import org.apache.flink.table.types.logical.TimestampType;
 
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
+
 import org.apache.calcite.rel.core.AggregateCall;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -54,6 +57,10 @@
 import java.lang.reflect.InvocationTargetException;
 import java.math.BigDecimal;
 import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /** Stream {@link ExecNode} for python time-based over operator. */
 public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
@@ -77,15 +84,34 @@ public class StreamExecPythonOverAggregate extends ExecNodeBase<RowData>
             "org.apache.flink.table.runtime.operators.python.aggregate.arrow.stream."
                     + "StreamArrowPythonProcTimeBoundedRowsOperator";
 
+    public static final String FIELD_NAME_OVER_SPEC = "overSpec";
+
+    @JsonProperty(FIELD_NAME_OVER_SPEC)
     private final OverSpec overSpec;
 
     public StreamExecPythonOverAggregate(
             OverSpec overSpec,
             InputProperty inputProperty,
             RowType outputType,
             String description) {
-        super(Collections.singletonList(inputProperty), outputType, description);
-        this.overSpec = overSpec;
+        this(
+                overSpec,
+                getNewNodeId(),
+                Collections.singletonList(inputProperty),
+                outputType,
+                description);
+    }
+
+    @JsonCreator
+    public StreamExecPythonOverAggregate(
+            @JsonProperty(FIELD_NAME_OVER_SPEC) OverSpec overSpec,
+            @JsonProperty(FIELD_NAME_ID) int id,
+            @JsonProperty(FIELD_NAME_INPUT_PROPERTIES) List<InputProperty> inputProperties,
+            @JsonProperty(FIELD_NAME_OUTPUT_TYPE) RowType outputType,
+            @JsonProperty(FIELD_NAME_DESCRIPTION) String description) {
+        super(id, inputProperties, outputType, description);
+        checkArgument(inputProperties.size() == 1);
+        this.overSpec = checkNotNull(overSpec);
     }
 
     @SuppressWarnings("unchecked")
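The original constructor now delegates to a new @JsonCreator constructor, supplying a fresh node id; during deserialization Jackson invokes the annotated constructor directly, restoring the node from its overSpec, id, inputProperties, outputType and description fields. At the API surface, this is what makes the compile-then-execute round trip in the Python test work. A hedged sketch of that round trip (t_env and insert_sql are stand-ins; getJsonPlan/executeJsonPlan are internal, experimental APIs reached through the wrapped Java TableEnvironment, as in the test above):

import json
from py4j.java_gateway import get_method

# compile the INSERT statement into a JSON plan (internal API)
json_plan = t_env._j_tenv.getJsonPlan(insert_sql)

# the plan is plain JSON, so it can be persisted and inspected before execution
with open('/tmp/over_aggregate_plan.json', 'w') as f:
    json.dump(json.loads(json_plan), f, indent=2)

# executing the plan deserializes the exec nodes again -- this is where the
# new @JsonCreator constructor runs; "await" is fetched via get_method
# because it is a reserved word in Python
get_method(t_env._j_tenv.executeJsonPlan(json_plan), "await")()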

flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/JsonSerdeCoverageTest.java

Lines changed: 0 additions & 1 deletion
With serde support in place, StreamExecPythonOverAggregate is dropped from the coverage test's list of exec nodes that do not yet support JSON serialization:

@@ -46,7 +46,6 @@ public class JsonSerdeCoverageTest {
                     "StreamExecWindowTableFunction",
                     "StreamExecGroupTableAggregate",
                     "StreamExecPythonGroupTableAggregate",
-                    "StreamExecPythonOverAggregate",
                     "StreamExecSort",
                     "StreamExecMultipleInput",
                     "StreamExecValues");
flink-table/flink-table-planner-blink/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/PythonOverAggregateJsonPlanTest.java

Lines changed: 147 additions & 0 deletions

@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.nodes.exec.stream;
+
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.PandasAggregateFunction;
+import org.apache.flink.table.planner.utils.StreamTableTestUtil;
+import org.apache.flink.table.planner.utils.TableTestBase;
+
+import org.junit.Before;
+import org.junit.Test;
+
+/** Test json serialization for over aggregate. */
+public class PythonOverAggregateJsonPlanTest extends TableTestBase {
+    private StreamTableTestUtil util;
+    private TableEnvironment tEnv;
+
+    @Before
+    public void setup() {
+        util = streamTestUtil(TableConfig.getDefault());
+        tEnv = util.getTableEnv();
+        String srcTableDdl =
+                "CREATE TABLE MyTable (\n"
+                        + " a int,\n"
+                        + " b varchar,\n"
+                        + " c int not null,\n"
+                        + " rowtime timestamp(3),\n"
+                        + " proctime as PROCTIME(),\n"
+                        + " watermark for rowtime as rowtime"
+                        + ") with (\n"
+                        + " 'connector' = 'values',\n"
+                        + " 'bounded' = 'false')";
+        tEnv.executeSql(srcTableDdl);
+        tEnv.createTemporarySystemFunction("pyFunc", new PandasAggregateFunction());
+    }
+
+    @Test
+    public void testProcTimeBoundedPartitionedRangeOver() {
+        String sinkTableDdl =
+                "CREATE TABLE MySink (\n"
+                        + " a bigint,\n"
+                        + " b bigint\n"
+                        + ") with (\n"
+                        + " 'connector' = 'values',\n"
+                        + " 'sink-insert-only' = 'false',\n"
+                        + " 'table-sink-class' = 'DEFAULT')";
+        tEnv.executeSql(sinkTableDdl);
+        String sql =
+                "insert into MySink SELECT a,\n"
+                        + " pyFunc(c, c) OVER (PARTITION BY a ORDER BY proctime\n"
+                        + " RANGE BETWEEN INTERVAL '2' HOUR PRECEDING AND CURRENT ROW)\n"
+                        + "FROM MyTable";
+        util.verifyJsonPlan(sql);
+    }
+
+    @Test
+    public void testProcTimeBoundedNonPartitionedRangeOver() {
+        String sinkTableDdl =
+                "CREATE TABLE MySink (\n"
+                        + " a bigint,\n"
+                        + " b bigint\n"
+                        + ") with (\n"
+                        + " 'connector' = 'values',\n"
+                        + " 'sink-insert-only' = 'false',\n"
+                        + " 'table-sink-class' = 'DEFAULT')";
+        tEnv.executeSql(sinkTableDdl);
+        String sql =
+                "insert into MySink SELECT a,\n"
+                        + " pyFunc(c, c) OVER (ORDER BY proctime\n"
+                        + " RANGE BETWEEN INTERVAL '10' SECOND PRECEDING AND CURRENT ROW)\n"
+                        + " FROM MyTable";
+        util.verifyJsonPlan(sql);
+    }
+
+    @Test
+    public void testProcTimeUnboundedPartitionedRangeOver() {
+        String sinkTableDdl =
+                "CREATE TABLE MySink (\n"
+                        + " a bigint,\n"
+                        + " b bigint\n"
+                        + ") with (\n"
+                        + " 'connector' = 'values',\n"
+                        + " 'sink-insert-only' = 'false',\n"
+                        + " 'table-sink-class' = 'DEFAULT')";
+        tEnv.executeSql(sinkTableDdl);
+        String sql =
+                "insert into MySink SELECT a,\n"
+                        + " pyFunc(c, c) OVER (PARTITION BY a ORDER BY proctime RANGE UNBOUNDED PRECEDING)\n"
+                        + "FROM MyTable";
+        util.verifyJsonPlan(sql);
+    }
+
+    @Test
+    public void testRowTimeBoundedPartitionedRowsOver() {
+        String sinkTableDdl =
+                "CREATE TABLE MySink (\n"
+                        + " a bigint,\n"
+                        + " b bigint\n"
+                        + ") with (\n"
+                        + " 'connector' = 'values',\n"
+                        + " 'sink-insert-only' = 'false',\n"
+                        + " 'table-sink-class' = 'DEFAULT')";
+        tEnv.executeSql(sinkTableDdl);
+        String sql =
+                "insert into MySink SELECT a,\n"
+                        + " pyFunc(c, c) OVER (PARTITION BY a ORDER BY rowtime\n"
+                        + " ROWS BETWEEN 5 preceding AND CURRENT ROW)\n"
+                        + "FROM MyTable";
+        util.verifyJsonPlan(sql);
+    }
+
+    @Test
+    public void testProcTimeBoundedPartitionedRowsOverWithBuiltinProctime() {
+        String sinkTableDdl =
+                "CREATE TABLE MySink (\n"
+                        + " a bigint,\n"
+                        + " b bigint\n"
+                        + ") with (\n"
+                        + " 'connector' = 'values',\n"
+                        + " 'sink-insert-only' = 'false',\n"
+                        + " 'table-sink-class' = 'DEFAULT')";
+        tEnv.executeSql(sinkTableDdl);
+        String sql =
+                "insert into MySink SELECT a, "
+                        + " pyFunc(c, c) OVER ("
+                        + " PARTITION BY a ORDER BY proctime() ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) "
+                        + "FROM MyTable";
+        util.verifyJsonPlan(sql);
+    }
+}
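Each test ends in util.verifyJsonPlan(sql), which round-trips the compiled plan through its JSON form. For orientation only, a hand-written Python dict approximating the shape of the StreamExecPythonOverAggregate entry inside such a plan — the field names come from the FIELD_NAME_* constants wired up above, while the surrounding key names and all placeholder values are assumptions, not taken from an actual plan file:

# Illustrative only: field names from the @JsonProperty annotations above;
# everything else (the "class" key, the placeholder values) is assumed.
python_over_aggregate_node = {
    "class": "org.apache.flink.table.planner.plan.nodes.exec.stream."
             "StreamExecPythonOverAggregate",
    "id": 3,                              # FIELD_NAME_ID, generated via getNewNodeId()
    "overSpec": {"...": "..."},           # FIELD_NAME_OVER_SPEC: partitioning, ordering, frames
    "inputProperties": [{"...": "..."}],  # exactly one entry (checkArgument above)
    "outputType": "ROW<...>",             # FIELD_NAME_OUTPUT_TYPE
    "description": "...",                 # FIELD_NAME_DESCRIPTION
}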
