Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2274,12 +2274,14 @@ class Analyzer(override val catalogManager: CatalogManager)
case Some(m) if Modifier.isStatic(m.getModifiers) =>
StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
propagateNull = false, returnNullable = scalarFunc.isResultNullable)
propagateNull = false, returnNullable = scalarFunc.isResultNullable,
isDeterministic = scalarFunc.isDeterministic)
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
returnNullable = scalarFunc.isResultNullable)
returnNullable = scalarFunc.isResultNullable,
isDeterministic = scalarFunc.isDeterministic)
case _ =>
// TODO: handle functions defined in Scala too - in Scala, even if a
// subclass does not override the default method in parent interface
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ case class ApplyFunctionExpression(
override def name: String = function.name()
override def dataType: DataType = function.resultType()
override def inputTypes: Seq[AbstractDataType] = function.inputTypes().toSeq
// Deterministic only when the wrapped function reports itself as deterministic AND every
// child expression is deterministic as well.
override lazy val deterministic: Boolean = function.isDeterministic &&
children.forall(_.deterministic)

private lazy val reusedRow = new SpecificInternalRow(function.inputTypes())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ object SerializerSupport {
* without invoking the function.
* @param returnNullable When false, indicating the invoked method will always return
* non-null value.
* @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
* will not apply certain optimizations such as constant folding.
*/
case class StaticInvoke(
staticObject: Class[_],
Expand All @@ -248,7 +250,8 @@ case class StaticInvoke(
arguments: Seq[Expression] = Nil,
inputTypes: Seq[AbstractDataType] = Nil,
propagateNull: Boolean = true,
returnNullable: Boolean = true) extends InvokeLike {
returnNullable: Boolean = true,
isDeterministic: Boolean = true) extends InvokeLike {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this default to true?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in the majority of cases a function is deterministic, so I'm defaulting to true here. This is similar to how we treat propagateNull and returnNullable here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for default true.

Copy link
Member

@HyukjinKwon HyukjinKwon Feb 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might be controversial to backport, because:

  1. V2 expressions are still unstable.
  2. It could lead to different results in a maintenance version upgrade if a user sets isDeterministic to false.
  3. It may cause a performance regression if a user sets isDeterministic to false.
  4. We haven't heard an actual complaint on the user mailing list.
  5. This makes an API (isDeterministic) work that has never worked before (is this a bug fix?)

While I agree with this being merged in 3.3.0, and I don't feel strongly about this in 3.2.X, maybe we can consider reverting this from branch-3.2 because it has both upsides and downsides. If we're worried about this change, we could instead issue a warning when isDeterministic from a V2 scalar function returns false.

I will leave it to you @sunchao.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a bug fix. A V2 catalog can return a function that's non-deterministic, while without the fix Spark can treat it as deterministic and apply related optimization rules (e.g., constant folding), which could cause correctness issues.

Since this is already in Spark 3.2.1, I don't see much benefit in reverting it and re-introducing the correctness issue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on keeping this in 3.2.x. This fixed the correctness issue, and we actually intentionally included this fix in the 3.2.1 release.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fine. I didn't have a strong preference, so I am okay with keeping it 👍.


val objectName = staticObject.getName.stripSuffix("$")
val cls = if (staticObject.getName == objectName) {
Expand All @@ -259,6 +262,7 @@ case class StaticInvoke(

override def nullable: Boolean = needNullCheck || returnNullable
override def children: Seq[Expression] = arguments
// Deterministic only when the invoked method is flagged as deterministic AND all arguments
// (which are exactly this expression's children) are deterministic.
override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: seems we can move this to InvokeLike

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about this too, but then we'd have to move isDeterministic there too and then override it in StaticInvoke and Invoke, so it doesn't seem we can save much. Plus, I think the deterministic property is not useful for NewInstance.


lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
@transient lazy val method = findMethod(cls, functionName, argClasses)
Expand Down Expand Up @@ -340,6 +344,8 @@ case class StaticInvoke(
* without invoking the function.
* @param returnNullable When false, indicating the invoked method will always return
* non-null value.
* @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
* will not apply certain optimizations such as constant folding.
*/
case class Invoke(
targetObject: Expression,
Expand All @@ -348,12 +354,14 @@ case class Invoke(
arguments: Seq[Expression] = Nil,
methodInputTypes: Seq[AbstractDataType] = Nil,
propagateNull: Boolean = true,
returnNullable : Boolean = true) extends InvokeLike {
returnNullable : Boolean = true,
isDeterministic: Boolean = true) extends InvokeLike {

lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)

override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
override def children: Seq[Expression] = targetObject +: arguments
// Determinism must account for every child: `children` is `targetObject +: arguments`, so a
// non-deterministic target object makes the whole invocation non-deterministic as well.
override lazy val deterministic: Boolean = isDeterministic && children.forall(_.deterministic)
override def inputTypes: Seq[AbstractDataType] =
if (methodInputTypes.nonEmpty) {
Seq(targetObject.dataType) ++ methodInputTypes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public String description() {
return "long_add";
}

private abstract static class JavaLongAddBase implements ScalarFunction<Long> {
public abstract static class JavaLongAddBase implements ScalarFunction<Long> {
private final boolean isResultNullable;

JavaLongAddBase(boolean isResultNullable) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package test.org.apache.spark.sql.connector.catalog.functions;

import java.util.Random;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.IntegerType;
import org.apache.spark.sql.types.StructType;

/**
 * Test V2 function which adds a random number to the input integer.
 *
 * <p>This unbound function only validates the input schema; the value computation is delegated
 * to the {@link BoundFunction} supplied at construction time.
 */
public class JavaRandomAdd implements UnboundFunction {
  /** The bound implementation returned by {@link #bind(StructType)} on success. */
  private final BoundFunction fn;

  /**
   * @param fn the bound function to hand out once the input type has been validated
   */
  public JavaRandomAdd(BoundFunction fn) {
    this.fn = fn;
  }

  /**
   * @return the function name; kept identical to the name used in {@link #description()} and
   *         by the bound implementations
   */
  @Override
  public String name() {
    return "rand_add";
  }

  /**
   * Validates that the input consists of exactly one integer column.
   *
   * @param inputType the schema of the function arguments
   * @return the bound function passed to the constructor
   * @throws UnsupportedOperationException if the argument count is not one, or the single
   *         argument is not of {@link IntegerType}
   */
  @Override
  public BoundFunction bind(StructType inputType) {
    if (inputType.fields().length != 1) {
      throw new UnsupportedOperationException("Expect exactly one argument");
    }
    if (!(inputType.fields()[0].dataType() instanceof IntegerType)) {
      throw new UnsupportedOperationException("Expect IntegerType");
    }
    return fn;
  }

  /** @return a human-readable description including the supported signature */
  @Override
  public String description() {
    return "rand_add: add a random integer to the input\n" +
      "rand_add(int) -> int";
  }

  /**
   * Shared metadata for all {@code rand_add} implementations: a single integer input, an
   * integer result, and a non-deterministic flag.
   */
  public abstract static class JavaRandomAddBase implements ScalarFunction<Integer> {
    @Override
    public DataType[] inputTypes() {
      return new DataType[] { DataTypes.IntegerType };
    }

    @Override
    public DataType resultType() {
      return DataTypes.IntegerType;
    }

    @Override
    public String name() {
      return "rand_add";
    }

    /** Always {@code false}: each call adds a freshly drawn random number. */
    @Override
    public boolean isDeterministic() {
      return false;
    }
  }

  /** Implementation that computes the result through {@code produceResult}. */
  public static class JavaRandomAddDefault extends JavaRandomAddBase {
    private final Random rand = new Random();

    @Override
    public Integer produceResult(InternalRow input) {
      return input.getInt(0) + rand.nextInt();
    }
  }

  /** Implementation that exposes an instance-level {@code invoke} method. */
  public static class JavaRandomAddMagic extends JavaRandomAddBase {
    private final Random rand = new Random();

    public int invoke(int input) {
      return input + rand.nextInt();
    }
  }

  /** Implementation that exposes a static {@code invoke} method. */
  public static class JavaRandomAddStaticMagic extends JavaRandomAddBase {
    private static final Random rand = new Random();

    public static int invoke(int input) {
      return input + rand.nextInt();
    }
  }
}

Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public BoundFunction bind(StructType inputType) {
return fn;
}

throw new UnsupportedOperationException("Except StringType");
throw new UnsupportedOperationException("Expect StringType");
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ package org.apache.spark.sql.connector
import java.util
import java.util.Collections

import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddMismatchMagic, JavaLongAddStaticMagic}
import test.org.apache.spark.sql.connector.catalog.functions._
import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd._
import test.org.apache.spark.sql.connector.catalog.functions.JavaRandomAdd._
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._

import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN}
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction, _}
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -428,6 +430,31 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
}
}

test("SPARK-37957: pass deterministic flag when creating V2 function expression") {
  // Asserts that the first ProjectExec in the executed plan of `df` exists and that none of
  // the expressions in its project list is deterministic.
  def checkNonDeterministic(df: DataFrame): Unit = {
    val project = df.queryExecution.executedPlan.collectFirst { case p: ProjectExec => p }
    assert(project.isDefined, "Expect to find ProjectExec")
    project.foreach { p =>
      assert(!p.projectList.exists(_.deterministic),
        "Expect expressions in projectList to be non-deterministic")
    }
  }

  def ident(name: String): Identifier = Identifier.of(Array("ns"), name)

  catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)

  // Every flavour of the non-deterministic function (produceResult, instance magic method,
  // static magic method) must yield a non-deterministic expression.
  Seq(new JavaRandomAddDefault, new JavaRandomAddMagic,
    new JavaRandomAddStaticMagic).foreach { fn =>
    addFunction(ident("rand_add"), new JavaRandomAdd(fn))
    checkNonDeterministic(sql("SELECT testcat.ns.rand_add(42)"))
  }

  // A function call is non-deterministic if one of its arguments is non-deterministic.
  // The non-deterministic argument function is the same for every iteration, so it is
  // registered once up front.
  addFunction(ident("rand_add"), new JavaRandomAdd(new JavaRandomAddDefault))
  Seq(new JavaLongAddDefault(true), new JavaLongAddMagic(true),
    new JavaLongAddStaticMagic(true)).foreach { fn =>
    addFunction(ident("add"), new JavaLongAdd(fn))
    checkNonDeterministic(sql("SELECT testcat.ns.add(10, testcat.ns.rand_add(42))"))
  }
}

private case class StrLen(impl: BoundFunction) extends UnboundFunction {
override def description(): String =
"""strlen: returns the length of the input string
Expand Down