diff --git a/python/paddle/static/io.py b/python/paddle/static/io.py index 3d3d4f30fa2d47..f4b61001a9fb6a 100644 --- a/python/paddle/static/io.py +++ b/python/paddle/static/io.py @@ -142,6 +142,11 @@ def _clone_var_in_block(block, var): ) +def _safe_load_pickle(file, encoding="ASCII"): + load_dict = pickle.Unpickler(file, encoding=encoding).load() + return load_dict + + def prepend_feed_ops( inference_program, feed_target_names, feed_holder_name='feed' ): @@ -1697,7 +1702,7 @@ def set_var(var, ndarray): if sys.platform == 'darwin' and sys.version_info.major == 3: load_dict = _pickle_loads_mac(parameter_file_name, f) else: - load_dict = pickle.load(f, encoding='latin1') + load_dict = _safe_load_pickle(f, encoding='latin1') load_dict = _pack_loaded_dict(load_dict) for v in parameter_list: assert ( @@ -1721,7 +1726,7 @@ def set_var(var, ndarray): ) with open(opt_file_name, 'rb') as f: - load_dict = pickle.load(f, encoding='latin1') + load_dict = _safe_load_pickle(f, encoding='latin1') for v in optimizer_var_list: assert ( v.name in load_dict @@ -2015,13 +2020,13 @@ def _load_vars_with_try_catch( if sys.platform == 'darwin' and sys.version_info.major == 3: para_dict = _pickle_loads_mac(parameter_file_name, f) else: - para_dict = pickle.load(f, encoding='latin1') + para_dict = _safe_load_pickle(f, encoding='latin1') para_dict = _pack_loaded_dict(para_dict) opt_file_name = model_prefix + ".pdopt" if os.path.exists(opt_file_name): with open(opt_file_name, 'rb') as f: - opti_dict = pickle.load(f, encoding='latin1') + opti_dict = _safe_load_pickle(f, encoding='latin1') para_dict.update(opti_dict)