@@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf)
6161 val instantiator = new EmptyScalaKryoInstantiator
6262 val kryo = instantiator.newKryo()
6363 kryo.setRegistrationRequired(registrationRequired)
64- val classLoader = Thread .currentThread.getContextClassLoader
64+
65+ val oldClassLoader = Thread .currentThread.getContextClassLoader
66+ val classLoader = defaultClassLoader.getOrElse(Thread .currentThread.getContextClassLoader)
6567
6668 // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
6769 // Do this before we invoke the user registrator so the user registrator can override this.
@@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf)
8486 try {
8587 val reg = Class .forName(regCls, true , classLoader).newInstance()
8688 .asInstanceOf [KryoRegistrator ]
89+
90+ // Use the default classloader when calling the user registrator.
91+ Thread .currentThread.setContextClassLoader(classLoader)
8792 reg.registerClasses(kryo)
8893 } catch {
89- case e : Exception =>
94+ case e : Exception =>
9095 throw new SparkException (s " Failed to invoke $regCls" , e)
96+ } finally {
97+ Thread .currentThread.setContextClassLoader(oldClassLoader)
9198 }
9299 }
93100
@@ -99,7 +106,7 @@ class KryoSerializer(conf: SparkConf)
99106 kryo
100107 }
101108
102- def newInstance (): SerializerInstance = {
109+ override def newInstance (): SerializerInstance = {
103110 new KryoSerializerInstance (this )
104111 }
105112}
@@ -108,20 +115,20 @@ private[spark]
108115class KryoSerializationStream (kryo : Kryo , outStream : OutputStream ) extends SerializationStream {
109116 val output = new KryoOutput (outStream)
110117
111- def writeObject [T : ClassTag ](t : T ): SerializationStream = {
118+ override def writeObject [T : ClassTag ](t : T ): SerializationStream = {
112119 kryo.writeClassAndObject(output, t)
113120 this
114121 }
115122
116- def flush () { output.flush() }
117- def close () { output.close() }
123+ override def flush () { output.flush() }
124+ override def close () { output.close() }
118125}
119126
120127private [spark]
121128class KryoDeserializationStream (kryo : Kryo , inStream : InputStream ) extends DeserializationStream {
122- val input = new KryoInput (inStream)
129+ private val input = new KryoInput (inStream)
123130
124- def readObject [T : ClassTag ](): T = {
131+ override def readObject [T : ClassTag ](): T = {
125132 try {
126133 kryo.readClassAndObject(input).asInstanceOf [T ]
127134 } catch {
@@ -131,31 +138,31 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser
131138 }
132139 }
133140
134- def close () {
141+ override def close () {
135142 // Kryo's Input automatically closes the input stream it is using.
136143 input.close()
137144 }
138145}
139146
140147private [spark] class KryoSerializerInstance (ks : KryoSerializer ) extends SerializerInstance {
141- val kryo = ks.newKryo()
148+ private val kryo = ks.newKryo()
142149
143150 // Make these lazy vals to avoid creating a buffer unless we use them
144- lazy val output = ks.newKryoOutput()
145- lazy val input = new KryoInput ()
151+ private lazy val output = ks.newKryoOutput()
152+ private lazy val input = new KryoInput ()
146153
147- def serialize [T : ClassTag ](t : T ): ByteBuffer = {
154+ override def serialize [T : ClassTag ](t : T ): ByteBuffer = {
148155 output.clear()
149156 kryo.writeClassAndObject(output, t)
150157 ByteBuffer .wrap(output.toBytes)
151158 }
152159
153- def deserialize [T : ClassTag ](bytes : ByteBuffer ): T = {
160+ override def deserialize [T : ClassTag ](bytes : ByteBuffer ): T = {
154161 input.setBuffer(bytes.array)
155162 kryo.readClassAndObject(input).asInstanceOf [T ]
156163 }
157164
158- def deserialize [T : ClassTag ](bytes : ByteBuffer , loader : ClassLoader ): T = {
165+ override def deserialize [T : ClassTag ](bytes : ByteBuffer , loader : ClassLoader ): T = {
159166 val oldClassLoader = kryo.getClassLoader
160167 kryo.setClassLoader(loader)
161168 input.setBuffer(bytes.array)
@@ -164,11 +171,11 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
164171 obj
165172 }
166173
167- def serializeStream (s : OutputStream ): SerializationStream = {
174+ override def serializeStream (s : OutputStream ): SerializationStream = {
168175 new KryoSerializationStream (kryo, s)
169176 }
170177
171- def deserializeStream (s : InputStream ): DeserializationStream = {
178+ override def deserializeStream (s : InputStream ): DeserializationStream = {
172179 new KryoDeserializationStream (kryo, s)
173180 }
174181}
0 commit comments