diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 29358b5740e51..84c7e80458997 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -24,7 +24,7 @@
 
 from pyspark.cloudpickle import print_exec
 from pyspark.java_gateway import local_connect_and_auth
-from pyspark.serializers import ChunkedStream
+from pyspark.serializers import ChunkedStream, pickle_protocol
 from pyspark.util import _exception_message
 
 if sys.version < '3':
@@ -110,7 +110,7 @@ def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
 
     def dump(self, value, f):
         try:
-            pickle.dump(value, f, 2)
+            pickle.dump(value, f, pickle_protocol)
         except pickle.PickleError:
             raise
         except Exception as e:
diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py
index 88519d7311fcc..bf92569c1e8c0 100644
--- a/python/pyspark/cloudpickle.py
+++ b/python/pyspark/cloudpickle.py
@@ -42,20 +42,26 @@
 """
 from __future__ import print_function
 
-import dis
-from functools import partial
-import imp
 import io
-import itertools
-import logging
+import dis
+import sys
+import types
 import opcode
-import operator
 import pickle
 import struct
-import sys
-import traceback
-import types
+import logging
 import weakref
+import operator
+import importlib
+import itertools
+import traceback
+from functools import partial
+
+
+# cloudpickle is meant for inter process communication: we expect all
+# communicating processes to run the same Python version hence we favor
+# communication speed over compatibility:
+DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
 
 
 if sys.version < '3':
@@ -72,6 +78,22 @@
     PY3 = True
 
 
+# Container for the global namespace to ensure consistent unpickling of
+# functions defined in dynamic modules (modules not registered in sys.modules).
+_dynamic_modules_globals = weakref.WeakValueDictionary()
+
+
+class _DynamicModuleFuncGlobals(dict):
+    """Global variables referenced by a function defined in a dynamic module
+
+    To avoid leaking references we store such context in a WeakValueDictionary
+    instance. However instances of python builtin types such as dict cannot
+    be used directly as values in such a construct, hence the need for a
+    derived class.
+    """
+    pass
+
+
 def _make_cell_set_template_code():
     """Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF
 
@@ -157,7 +179,7 @@ def cell_set(cell, value):
     )(value)
 
 
-#relevant opcodes
+# relevant opcodes
 STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
 DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
 LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
@@ -167,7 +189,7 @@ def cell_set(cell, value):
 
 
 def islambda(func):
-    return getattr(func,'__name__') == '<lambda>'
+    return getattr(func, '__name__') == '<lambda>'
 
 
 _BUILTIN_TYPE_NAMES = {}
@@ -248,7 +270,9 @@ class CloudPickler(Pickler):
     dispatch = Pickler.dispatch.copy()
 
     def __init__(self, file, protocol=None):
-        Pickler.__init__(self, file, protocol)
+        if protocol is None:
+            protocol = DEFAULT_PROTOCOL
+        Pickler.__init__(self, file, protocol=protocol)
         # set of modules to unpickle
         self.modules = set()
         # map ids to dictionary. used to ensure that functions can share global env
@@ -267,42 +291,26 @@ def dump(self, obj):
 
     def save_memoryview(self, obj):
         self.save(obj.tobytes())
+
     dispatch[memoryview] = save_memoryview
 
     if not PY3:
         def save_buffer(self, obj):
             self.save(str(obj))
-        dispatch[buffer] = save_buffer  # noqa: F821 'buffer' was removed in Python 3
-
-    def save_unsupported(self, obj):
-        raise pickle.PicklingError("Cannot pickle objects of type %s" % type(obj))
-    dispatch[types.GeneratorType] = save_unsupported
 
-    # itertools objects do not pickle!
-    for v in itertools.__dict__.values():
-        if type(v) is type:
-            dispatch[v] = save_unsupported
+        dispatch[buffer] = save_buffer  # noqa: F821 'buffer' was removed in Python 3
 
     def save_module(self, obj):
        """
         Save a module as an import
         """
-        mod_name = obj.__name__
-        # If module is successfully found then it is not a dynamically created module
-        if hasattr(obj, '__file__'):
-            is_dynamic = False
-        else:
-            try:
-                _find_module(mod_name)
-                is_dynamic = False
-            except ImportError:
-                is_dynamic = True
-
         self.modules.add(obj)
-        if is_dynamic:
-            self.save_reduce(dynamic_subimport, (obj.__name__, vars(obj)), obj=obj)
+        if _is_dynamic(obj):
+            self.save_reduce(dynamic_subimport, (obj.__name__, vars(obj)),
+                             obj=obj)
         else:
             self.save_reduce(subimport, (obj.__name__,), obj=obj)
+
     dispatch[types.ModuleType] = save_module
 
     def save_codeobject(self, obj):
@@ -323,6 +331,7 @@ def save_codeobject(self, obj):
                 obj.co_firstlineno, obj.co_lnotab, obj.co_freevars, obj.co_cellvars
             )
         self.save_reduce(types.CodeType, args, obj=obj)
+
     dispatch[types.CodeType] = save_codeobject
 
     def save_function(self, obj, name=None):
@@ -369,9 +378,14 @@ def save_function(self, obj, name=None):
         if modname == '__main__':
             themodule = None
 
+        try:
+            lookedup_by_name = getattr(themodule, name, None)
+        except Exception:
+            lookedup_by_name = None
+
         if themodule:
             self.modules.add(themodule)
-            if getattr(themodule, name, None) is obj:
+            if lookedup_by_name is obj:
                 return self.save_global(obj, name)
 
         # a builtin_function_or_method which comes in as an attribute of some
@@ -401,8 +415,7 @@ def save_function(self, obj, name=None):
                 return
             else:
                 # func is nested
-                klass = getattr(themodule, name, None)
-                if klass is None or klass is not obj:
+                if lookedup_by_name is None or lookedup_by_name is not obj:
                     self.save_function_tuple(obj)
                     return
 
@@ -416,6 +429,7 @@ def save_function(self, obj, name=None):
             else:
                 write(pickle.GLOBAL + modname + '\n' + name + '\n')
             self.memoize(obj)
+
     dispatch[types.FunctionType] = save_function
 
     def _save_subimports(self, code, top_level_dependencies):
@@ -423,19 +437,22 @@
         """
         Ensure de-pickler imports any package child-modules that
         are needed by the function
         """
+        # check if any known dependency is an imported package
         for x in top_level_dependencies:
             if isinstance(x, types.ModuleType) and hasattr(x, '__package__') and x.__package__:
                 # check if the package has any currently loaded sub-imports
                 prefix = x.__name__ + '.'
-                for name, module in sys.modules.items():
+                # A concurrent thread could mutate sys.modules,
+                # make sure we iterate over a copy to avoid exceptions
+                for name in list(sys.modules):
                     # Older versions of pytest will add a "None" module to sys.modules.
                     if name is not None and name.startswith(prefix):
                         # check whether the function can address the sub-module
                         tokens = set(name[len(prefix):].split('.'))
                         if not tokens - set(code.co_names):
                             # ensure unpickler executes this import
-                            self.save(module)
+                            self.save(sys.modules[name])
                             # then discards the reference to it
                             self.write(pickle.POP)
 
@@ -450,6 +467,15 @@ def save_dynamic_class(self, obj):
         clsdict = dict(obj.__dict__)  # copy dict proxy to a dict
         clsdict.pop('__weakref__', None)
 
+        # For ABCMeta in python3.7+, remove _abc_impl as it is not picklable.
+        # This is a fix which breaks the cache but this only makes the first
+        # calls to issubclass slower.
+        if "_abc_impl" in clsdict:
+            import abc
+            (registry, _, _, _) = abc._get_dump(obj)
+            clsdict["_abc_impl"] = [subclass_weakref()
+                                    for subclass_weakref in registry]
+
         # On PyPy, __doc__ is a readonly attribute, so we need to include it in
         # the initial skeleton class.  This is safe because we know that the
         # doc can't participate in a cycle with the original class.
@@ -541,9 +567,13 @@ def save_function_tuple(self, func):
             'globals': f_globals,
             'defaults': defaults,
             'dict': dct,
-            'module': func.__module__,
             'closure_values': closure_values,
+            'module': func.__module__,
+            'name': func.__name__,
+            'doc': func.__doc__,
         }
+        if hasattr(func, '__annotations__') and sys.version_info >= (3, 7):
+            state['annotations'] = func.__annotations__
         if hasattr(func, '__qualname__'):
             state['qualname'] = func.__qualname__
         save(state)
@@ -568,8 +598,7 @@ def extract_code_globals(cls, co):
             # PyPy "builtin-code" object
             out_names = set()
         else:
-            out_names = set(names[oparg]
-                            for op, oparg in _walk_global_ops(co))
+            out_names = {names[oparg] for _, oparg in _walk_global_ops(co)}
 
         # see if nested function have any global refs
         if co.co_consts:
@@ -610,7 +639,16 @@ def extract_func_data(self, func):
         # save the dict
         dct = func.__dict__
 
-        base_globals = self.globals_ref.get(id(func.__globals__), {})
+        base_globals = self.globals_ref.get(id(func.__globals__), None)
+        if base_globals is None:
+            # For functions defined in a well behaved module use
+            # vars(func.__module__) for base_globals. This is necessary to
+            # share the global variables across multiple pickled functions from
+            # this module.
+            if hasattr(func, '__module__') and func.__module__ is not None:
+                base_globals = func.__module__
+            else:
+                base_globals = {}
         self.globals_ref[id(func.__globals__)] = base_globals
 
         return (code, f_globals, defaults, closure, dct, base_globals)
@@ -619,6 +657,7 @@ def save_builtin_function(self, obj):
         if obj.__module__ == "__builtin__":
             return self.save_global(obj)
         return self.save_function(obj)
+
     dispatch[types.BuiltinFunctionType] = save_builtin_function
 
     def save_global(self, obj, name=None, pack=struct.pack):
@@ -628,6 +667,13 @@ def save_global(self, obj, name=None, pack=struct.pack):
         The name of this method is somewhat misleading: all types get
         dispatched here.
""" + if obj is type(None): + return self.save_reduce(type, (None,), obj=obj) + elif obj is type(Ellipsis): + return self.save_reduce(type, (Ellipsis,), obj=obj) + elif obj is type(NotImplemented): + return self.save_reduce(type, (NotImplemented,), obj=obj) + if obj.__module__ == "__main__": return self.save_dynamic_class(obj) @@ -657,7 +703,8 @@ def save_instancemethod(self, obj): self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) else: self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), - obj=obj) + obj=obj) + dispatch[types.MethodType] = save_instancemethod def save_inst(self, obj): @@ -711,11 +758,13 @@ def save_inst(self, obj): def save_property(self, obj): # properties not correctly saved in python self.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__), obj=obj) + dispatch[property] = save_property def save_classmethod(self, obj): orig_func = obj.__func__ self.save_reduce(type(obj), (orig_func,), obj=obj) + dispatch[classmethod] = save_classmethod dispatch[staticmethod] = save_classmethod @@ -726,7 +775,7 @@ def __getitem__(self, item): return item items = obj(Dummy()) if not isinstance(items, tuple): - items = (items, ) + items = (items,) return self.save_reduce(operator.itemgetter, items) if type(operator.itemgetter) is type: @@ -757,16 +806,16 @@ def __getattribute__(self, item): def save_file(self, obj): """Save a file""" try: - import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute + import StringIO as pystringIO # we can't use cStringIO as it lacks the name attribute except ImportError: import io as pystringIO - if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): + if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") if obj is sys.stdout: - return self.save_reduce(getattr, (sys,'stdout'), obj=obj) + return self.save_reduce(getattr, (sys, 'stdout'), obj=obj) if obj is sys.stderr: - return self.save_reduce(getattr, (sys,'stderr'), obj=obj) + return self.save_reduce(getattr, (sys, 'stderr'), obj=obj) if obj is sys.stdin: raise pickle.PicklingError("Cannot pickle standard input") if obj.closed: @@ -845,6 +894,7 @@ def is_tornado_coroutine(func): return False return gen.is_coroutine_function(func) + def _rebuild_tornado_coroutine(func): from tornado import gen return gen.coroutine(func) @@ -852,36 +902,55 @@ def _rebuild_tornado_coroutine(func): # Shorthands for legacy support -def dump(obj, file, protocol=2): - CloudPickler(file, protocol).dump(obj) +def dump(obj, file, protocol=None): + """Serialize obj as bytes streamed into file + protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to + pickle.HIGHEST_PROTOCOL. This setting favors maximum communication speed + between processes running the same Python version. + + Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure + compatibility with older versions of Python. + """ + CloudPickler(file, protocol=protocol).dump(obj) -def dumps(obj, protocol=2): + +def dumps(obj, protocol=None): + """Serialize obj as a string of bytes allocated in memory + + protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to + pickle.HIGHEST_PROTOCOL. This setting favors maximum communication speed + between processes running the same Python version. + + Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure + compatibility with older versions of Python. 
+ """ file = StringIO() try: - cp = CloudPickler(file,protocol) + cp = CloudPickler(file, protocol=protocol) cp.dump(obj) return file.getvalue() finally: file.close() + # including pickles unloading functions in this namespace load = pickle.load loads = pickle.loads -#hack for __import__ not working as desired +# hack for __import__ not working as desired def subimport(name): __import__(name) return sys.modules[name] def dynamic_subimport(name, vars): - mod = imp.new_module(name) + mod = types.ModuleType(name) mod.__dict__.update(vars) - sys.modules[name] = mod return mod + # restores function attributes def _restore_attr(obj, attr): for key, val in attr.items(): @@ -908,7 +977,7 @@ def _modules_to_main(modList): if type(modname) is str: try: mod = __import__(modname) - except Exception as e: + except Exception: sys.stderr.write('warning: could not import %s\n. ' 'Your function may unexpectedly error due to this import failing;' 'A version mismatch is likely. Specific error was:\n' % modname) @@ -917,7 +986,7 @@ def _modules_to_main(modList): setattr(main, mod.__name__, mod) -#object generators: +# object generators: def _genpartial(func, args, kwds): if not args: args = () @@ -925,9 +994,11 @@ def _genpartial(func, args, kwds): kwds = {} return partial(func, *args, **kwds) + def _gen_ellipsis(): return Ellipsis + def _gen_not_implemented(): return NotImplemented @@ -988,9 +1059,19 @@ def _fill_function(*args): else: raise ValueError('Unexpected _fill_value arguments: %r' % (args,)) - func.__globals__.update(state['globals']) + # Only set global variables that do not exist. + for k, v in state['globals'].items(): + if k not in func.__globals__: + func.__globals__[k] = v + func.__defaults__ = state['defaults'] func.__dict__ = state['dict'] + if 'annotations' in state: + func.__annotations__ = state['annotations'] + if 'doc' in state: + func.__doc__ = state['doc'] + if 'name' in state: + func.__name__ = state['name'] if 'module' in state: func.__module__ = state['module'] if 'qualname' in state: @@ -1021,6 +1102,20 @@ def _make_skel_func(code, cell_count, base_globals=None): """ if base_globals is None: base_globals = {} + elif isinstance(base_globals, str): + base_globals_name = base_globals + try: + # First try to reuse the globals from the module containing the + # function. If it is not possible to retrieve it, fallback to an + # empty dictionary. + base_globals = vars(importlib.import_module(base_globals)) + except ImportError: + base_globals = _dynamic_modules_globals.get( + base_globals_name, None) + if base_globals is None: + base_globals = _DynamicModuleFuncGlobals() + _dynamic_modules_globals[base_globals_name] = base_globals + base_globals['__builtins__'] = __builtins__ closure = ( @@ -1036,28 +1131,50 @@ def _rehydrate_skeleton_class(skeleton_class, class_dict): See CloudPickler.save_dynamic_class for more info. """ + registry = None for attrname, attr in class_dict.items(): - setattr(skeleton_class, attrname, attr) + if attrname == "_abc_impl": + registry = attr + else: + setattr(skeleton_class, attrname, attr) + if registry is not None: + for subclass in registry: + skeleton_class.register(subclass) + return skeleton_class -def _find_module(mod_name): +def _is_dynamic(module): """ - Iterate over each part instead of calling imp.find_module directly. - This function is able to find submodules (e.g. sickit.tree) + Return True if the module is special module that cannot be imported by its + name. 
""" - path = None - for part in mod_name.split('.'): - if path is not None: - path = [path] - file, path, description = imp.find_module(part, path) - if file is not None: - file.close() - return path, description + # Quick check: module that have __file__ attribute are not dynamic modules. + if hasattr(module, '__file__'): + return False + + if hasattr(module, '__spec__'): + return module.__spec__ is None + else: + # Backward compat for Python 2 + import imp + try: + path = None + for part in module.__name__.split('.'): + if path is not None: + path = [path] + f, path, description = imp.find_module(part, path) + if f is not None: + f.close() + except ImportError: + return True + return False + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" + def _getobject(modname, attribute): mod = __import__(modname, fromlist=[attribute]) return mod.__dict__[attribute] diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fd4695210fb7c..ba8c0ce4bf8bc 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -61,12 +61,11 @@ if sys.version < '3': import cPickle as pickle - protocol = 2 from itertools import izip as zip, imap as map else: import pickle - protocol = 3 xrange = range +pickle_protocol = pickle.HIGHEST_PROTOCOL from pyspark import cloudpickle from pyspark.util import _exception_message @@ -606,7 +605,7 @@ class PickleSerializer(FramedSerializer): """ def dumps(self, obj): - return pickle.dumps(obj, protocol) + return pickle.dumps(obj, pickle_protocol) if sys.version >= '3': def loads(self, obj, encoding="bytes"): @@ -620,7 +619,7 @@ class CloudPickleSerializer(PickleSerializer): def dumps(self, obj): try: - return cloudpickle.dumps(obj, 2) + return cloudpickle.dumps(obj, pickle_protocol) except pickle.PickleError: raise except Exception as e: diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index b2a544b8de78a..50d88a59daf2c 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -605,7 +605,7 @@ def test_distinct(self): def test_external_group_by_key(self): self.sc._conf.set("spark.python.worker.memory", "1m") - N = 200001 + N = 2000001 kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x)) gkv = kv.groupByKey().cache() self.assertEqual(3, gkv.count())