7 changes: 7 additions & 0 deletions connector/connect/client/jvm/pom.xml
@@ -221,6 +221,13 @@
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-jar-plugin</artifactId>
         <executions>
+          <execution>
+            <id>default-jar</id>
+            <phase>compile</phase>
+            <goals>
+              <goal>jar</goal>
+            </goals>
+          </execution>
           <execution>
             <id>prepare-test-jar</id>
             <phase>test-compile</phase>
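(The new default-jar execution is bound to the compile phase, presumably so the non-test client jar already exists by the time the UDF E2E tests upload it to the server via addArtifact; see the RemoteSparkSession change below.)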
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql

import java.util.Arrays

import org.apache.spark.sql.connect.client.SparkResult
import org.apache.spark.sql.connect.client.util.RemoteSparkSession
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class UDFClassLoadingE2ESuite extends RemoteSparkSession {

  test("load udf with default stub class loader") {
    val rows = spark.range(10).filter(n => n % 2 == 0).collectAsList()
    assert(rows == Arrays.asList[Long](0, 2, 4, 6, 8))
  }

  test("update class loader after stubbing: new session") {
    // Session1 uses the stubbed SparkResult class
    val session1 = spark.newSession()
    addClientTestArtifactInServerClasspath(session1)
    val ds = session1.range(10).filter(n => n % 2 == 0)

    val rows = ds.collectAsList()
    assert(rows == Arrays.asList[Long](0, 2, 4, 6, 8))

    // Session2 uses the real SparkResult class
    val session2 = spark.newSession()
    addClientTestArtifactInServerClasspath(session2)
    addClientTestArtifactInServerClasspath(session2, testJar = false)
    val rows2 = session2
      .range(10)
      .filter(n => {
        // Try to use SparkResult
        new SparkResult[Int](null, null, null)
        n > 5
      })
      .collectAsList()
    assert(rows2 == Arrays.asList[Long](6, 7, 8, 9))
  }

  test("update class loader after stubbing: same session") {
    val session = spark.newSession()
    addClientTestArtifactInServerClasspath(session)
    val ds = session.range(10).filter(n => n % 2 == 0)

    // SparkResult is loaded as a stubbed class
    val rows = ds.collectAsList()
    assert(rows == Arrays.asList[Long](0, 2, 4, 6, 8))

    // Upload the real client jar so that SparkResult can be used in the udf
    addClientTestArtifactInServerClasspath(session, testJar = false)
    val rows2 = session.range(10).filter(n => {
      // Try to use SparkResult
      new SparkResult[Int](null, null, null)
      n > 5
    }).collectAsList()
    assert(rows2 == Arrays.asList[Long](6, 7, 8, 9))
  }

  // This dummy method generates a lambda in the test class with SparkResult in its signature.
  // This causes a class-loading issue on the server side, as the client jar is
  // not on the server classpath.
  def dummyMethod(): Unit = {
    val df = spark.sql("select val from (values ('Hello'), ('World')) as t(val)")
    df.withResult { result =>
      val schema = result.schema
      assert(schema == StructType(StructField("val", StringType, nullable = false) :: Nil))
    }
  }
}
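For context on what these tests exercise: the server deserializes the client's UDF payload (a UdfPacket in this codebase) with the session's class loader, and the stub loader fabricates any client-only classes referenced along the way. A minimal sketch of that round trip, assembled only from calls that appear elsewhere in this diff (the wiring is simplified and is not the actual server path):

object UdfRoundTripSketch {
  import org.apache.spark.sql.connect.common.UdfPacket
  import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader, Utils}

  def main(args: Array[String]): Unit = {
    // "Client" side: pack a closure, as StubClassDummyUdf.packDummyUdf does below.
    val bytes = Utils.serialize[UdfPacket](new UdfPacket((x: Long) => x + 1, Seq.empty, null))

    // "Server" side: deserialize with a child-first loader whose fallback fabricates
    // unknown classes, mirroring the new SparkConnectArtifactManager.classloader.
    val sessionClassLoader = new ChildFirstURLClassLoader(
      Array.empty, // jars added via session.addArtifact would surface here
      StubClassLoader(null, "org.apache.spark.sql.connect.client" :: Nil),
      Utils.getContextOrSparkClassLoader)
    val packet = Utils.deserialize[UdfPacket](bytes, sessionClassLoader)
    println(packet) // here the closure's class is on the parent classpath, so this succeeds
  }
}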
@@ -30,7 +30,7 @@ object IntegrationTestUtils {
   // System properties used for testing and debugging
   private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
   // Enable this flag to print all client debug log + server logs to the console
-  private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean
+  private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "true").toBoolean

   private[sql] lazy val scalaVersion = {
     versionNumberString.split('.') match {
@@ -99,19 +99,7 @@ object SparkConnectServerUtils {
       .map(clientTestJar => Seq("--jars", clientTestJar.getCanonicalPath))
       .getOrElse(Seq.empty)

-    // For UDF maven E2E tests, the server needs the client code to find the UDFs defined in tests.
-    val connectClientTestJar = tryFindJar(
-      "connector/connect/client/jvm",
-      // SBT passes the client & test jars to the server process automatically.
-      // So we skip building or finding this jar for SBT.
-      "sbt-tests-do-not-need-this-jar",
-      "spark-connect-client-jvm",
-      test = true)
-      .map(clientTestJar => Seq(clientTestJar.getCanonicalPath))
-      .getOrElse(Seq.empty)
-
-    val allJars = catalystTestJar ++ connectClientTestJar
-    val jarsConfigs = Seq("--jars", allJars.mkString(","))
+    val jarsConfigs = Seq("--jars", catalystTestJar.mkString(","))

     // Use InMemoryTableCatalog for V2 writer tests
     val writerV2Configs = Seq(
@@ -211,6 +199,23 @@ trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll {
         throw error
       }
     }
+
+    addClientTestArtifactInServerClasspath(spark)
   }
+
+  // For UDF maven E2E tests, the server needs the client test code to find the UDFs defined in
+  // tests.
+  private[sql] def addClientTestArtifactInServerClasspath(
+      session: SparkSession,
+      testJar: Boolean = true): Unit = {
+    tryFindJar(
+      "connector/connect/client/jvm",
+      // SBT passes the client & test jars to the server process automatically.
+      // So we skip building or finding this jar for SBT.
+      "sbt-tests-do-not-need-this-jar",
+      "spark-connect-client-jvm",
+      test = testJar
+    ).foreach(clientTestJar => session.addArtifact(clientTestJar.getCanonicalPath))
+  }

   override def afterAll(): Unit = {
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.connect.artifact

 import java.io.File
-import java.net.{URI, URL, URLClassLoader}
+import java.net.{URI, URL}
 import java.nio.file.{Files, Path, Paths, StandardCopyOption}
 import java.util.concurrent.CopyOnWriteArrayList
 import javax.ws.rs.core.UriBuilder
@@ -31,12 +31,13 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}

 import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_CLASSES
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
 import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.storage.{CacheId, StorageLevel}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader, Utils}

 /**
  * The Artifact Manager for the [[SparkConnectService]].
@@ -161,7 +162,9 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
    */
   def classloader: ClassLoader = {
     val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL
-    new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+    val stubClassLoader =
+      StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES))
+    new ChildFirstURLClassLoader(urls.toArray, stubClassLoader, Utils.getContextOrSparkClassLoader)
   }

   /**
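For reference, the idea behind StubClassLoader (the real implementation lives in org.apache.spark.util and is not part of this diff): when a class name matches the configured prefixes, the loader defines an empty class on the fly instead of throwing ClassNotFoundException. A rough, hypothetical sketch using plain ASM (Spark would use its shaded ASM; the class name and details here are illustrative):

import org.objectweb.asm.{ClassWriter, Opcodes}

class PrefixStubClassLoader(parent: ClassLoader, prefixes: Seq[String])
    extends ClassLoader(parent) {
  override def findClass(name: String): Class[_] = {
    if (!prefixes.exists(name.startsWith)) {
      throw new ClassNotFoundException(name)
    }
    // Emit bytecode for a bare public class with no members. Deserialization can
    // resolve the name, but any member lookup later fails with NoSuchMethodException,
    // which is exactly what StubClassLoaderSuite asserts below.
    val cw = new ClassWriter(0)
    cw.visit(Opcodes.V1_8, Opcodes.ACC_PUBLIC, name.replace('.', '/'), null,
      "java/lang/Object", null)
    cw.visitEnd()
    val bytes = cw.toByteArray
    defineClass(name, bytes, 0, bytes.length)
  }
}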
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.artifact

// To generate a jar from the source file:
//   `scalac StubClassDummyUdf.scala -d udf.jar`
// To remove class A from the jar:
//   `jar -xvf udf.jar`, delete A.class and A$.class, then
//   `jar -cvf udf_noA.jar org/`
class StubClassDummyUdf {
  val udf: Int => Int = (x: Int) => x + 1
  val dummy = (x: Int) => A(x)
}

case class A(x: Int) { def get: Int = x + 5 }

// The code used to generate the serialized udf file
object StubClassDummyUdf {
  import java.io.{BufferedOutputStream, File, FileOutputStream}

  import org.apache.spark.sql.connect.common.UdfPacket
  import org.apache.spark.util.Utils

  def packDummyUdf(): String = {
    val byteArray =
      Utils.serialize[UdfPacket](
        new UdfPacket(
          new StubClassDummyUdf().udf,
          Seq.empty,
          null))
    val file = new File("src/test/resources/udf")
    val target = new BufferedOutputStream(new FileOutputStream(file))
    try {
      target.write(byteArray)
      file.getAbsolutePath
    } finally {
      target.close()
    }
  }
}
Binary file added connector/connect/server/src/test/resources/udf
Binary file added connector/connect/server/src/test/resources/udf_noA.jar
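A hypothetical one-off driver to regenerate the binary udf resource above (the object and path come from StubClassDummyUdf; assumed to be run from connector/connect/server with the test classes on the classpath):

object RegenerateUdfResource {
  def main(args: Array[String]): Unit = {
    // Writes the serialized UdfPacket to src/test/resources/udf and returns the path.
    val path = org.apache.spark.sql.connect.artifact.StubClassDummyUdf.packDummyUdf()
    println(s"Wrote serialized UdfPacket to $path")
  }
}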
@@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.artifact

import java.io.File
import java.nio.file.{Files, Paths}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.connect.common.UdfPacket
import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader, Utils}

class StubClassLoaderSuite extends SparkFunSuite {

  private val udfByteArray: Array[Byte] = Files.readAllBytes(Paths.get("src/test/resources/udf"))
  private val udfNoAJar = new File("src/test/resources/udf_noA.jar").toURI.toURL

  test("find class with stub class") {
    val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true)
    val cls = cl.findClass("my.name.HelloWorld")
    assert(cls.getName === "my.name.HelloWorld")
    assert(cl.lastStubbed === "my.name.HelloWorld")
  }

  test("class for name with stub class") {
    val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true)
    // scalastyle:off classforname
    val cls = Class.forName("my.name.HelloWorld", false, cl)
    // scalastyle:on classforname
    assert(cls.getName === "my.name.HelloWorld")
    assert(cl.lastStubbed === "my.name.HelloWorld")
  }

  test("filter class to stub") {
    val list = "my.name" :: Nil
    val cl = StubClassLoader(getClass().getClassLoader(), list)
    val cls = cl.findClass("my.name.HelloWorld")
    assert(cls.getName === "my.name.HelloWorld")

    intercept[ClassNotFoundException] {
      cl.findClass("name.my.GoodDay")
    }
  }

  test("load udf") {
    // See src/test/resources/StubClassDummyUdf for how the udf and jar are created.
    val sysClassLoader = getClass.getClassLoader()
    val stubClassLoader = new RecordedStubClassLoader(null, _ => true)

    // Install the artifact without class A.
    val sessionClassLoader = new ChildFirstURLClassLoader(
      Array(udfNoAJar),
      stubClassLoader,
      sysClassLoader)
    // Load the udf, which uses A in the same class.
    deserializeUdf(sessionClassLoader)
    // Class A should be stubbed.
    assert(stubClassLoader.lastStubbed === "org.apache.spark.sql.connect.artifact.A")
  }

  test("unload stub class") {
    // See src/test/resources/StubClassDummyUdf for how the udf and jar are created.
    val sysClassLoader = getClass.getClassLoader()
    val stubClassLoader = new RecordedStubClassLoader(null, _ => true)

    val cl1 = new ChildFirstURLClassLoader(
      Array.empty,
      stubClassLoader,
      sysClassLoader)

    // Loading the dummy udf fails.
    intercept[Exception] {
      deserializeUdf(cl1)
    }
    // The missing class was successfully stubbed.
    assert(stubClassLoader.lastStubbed ===
      "org.apache.spark.sql.connect.artifact.StubClassDummyUdf")

    // Creating a new class loader will unpack the udf correctly.
    val cl2 = new ChildFirstURLClassLoader(
      Array(udfNoAJar),
      stubClassLoader, // even with the same stub class loader.
      sysClassLoader)
    // Should be able to load after the artifact is added.
    deserializeUdf(cl2)
  }

  test("throw no such method if trying to access methods on stub class") {
    // See src/test/resources/StubClassDummyUdf for how the udf and jar are created.
    val sysClassLoader = getClass.getClassLoader()
    val stubClassLoader = new RecordedStubClassLoader(null, _ => true)

    val sessionClassLoader = new ChildFirstURLClassLoader(
      Array.empty,
      stubClassLoader,
      sysClassLoader)

    // Loading the dummy udf fails.
    val exception = intercept[Exception] {
      deserializeUdf(sessionClassLoader)
    }
    // The missing class was successfully stubbed.
    assert(stubClassLoader.lastStubbed ===
      "org.apache.spark.sql.connect.artifact.StubClassDummyUdf")
    // But the method lookup on the stub class fails.
    val cause = exception.getCause
    assert(cause.isInstanceOf[NoSuchMethodException])
    assert(
      cause.getMessage.contains("org.apache.spark.sql.connect.artifact.StubClassDummyUdf"),
      cause.getMessage)
  }

  private def deserializeUdf(sessionClassLoader: ClassLoader): UdfPacket = {
    Utils.deserialize[UdfPacket](
      udfByteArray,
      sessionClassLoader)
  }
}

class RecordedStubClassLoader(parent: ClassLoader, shouldStub: String => Boolean)
    extends StubClassLoader(parent, shouldStub) {
  var lastStubbed: String = _

  override def findClass(name: String): Class[_] = {
    if (shouldStub(name)) {
      lastStubbed = name
    }
    super.findClass(name)
  }
}