From 50c3c56cf9fe922f448d9d634408d94e80c3da16 Mon Sep 17 00:00:00 2001 From: William Hicks Date: Tue, 22 Oct 2024 13:29:00 -0700 Subject: [PATCH] Remove type from pickle header for CumlArray --- python/cuml/cuml/internals/array.py | 7 ++----- python/cuml/cuml/tests/test_array.py | 2 -- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/python/cuml/cuml/internals/array.py b/python/cuml/cuml/internals/array.py index c30d609563..71918f77f5 100644 --- a/python/cuml/cuml/internals/array.py +++ b/python/cuml/cuml/internals/array.py @@ -728,9 +728,8 @@ def host_serialize(self): @classmethod def host_deserialize(cls, header, frames): - typ = pickle.loads(header["type-serialized"]) assert all(not is_cuda for is_cuda in header["is-cuda"]) - obj = typ.deserialize(header, frames) + obj = cls.deserialize(header, frames) return obj @nvtx_annotate( @@ -748,9 +747,8 @@ def device_serialize(self): @classmethod def device_deserialize(cls, header, frames): - typ = pickle.loads(header["type-serialized"]) assert all(is_cuda for is_cuda in header["is-cuda"]) - obj = typ.deserialize(header, frames) + obj = cls.deserialize(header, frames) return obj @nvtx_annotate( @@ -761,7 +759,6 @@ def device_deserialize(cls, header, frames): def serialize(self, mem_type=None) -> Tuple[dict, list]: mem_type = self.mem_type if mem_type is None else mem_type header = { - "type-serialized": pickle.dumps(type(self)), "constructor-kwargs": { "dtype": self.dtype.str, "shape": self.shape, diff --git a/python/cuml/cuml/tests/test_array.py b/python/cuml/cuml/tests/test_array.py index f64683717d..798579f28a 100644 --- a/python/cuml/cuml/tests/test_array.py +++ b/python/cuml/cuml/tests/test_array.py @@ -535,8 +535,6 @@ def test_serialize(inp, to_serialize_mem_type, from_serialize_mem_type): with using_memory_type(from_serialize_mem_type): ary2 = CumlArray.deserialize(header, frames) - assert pickle.loads(header["type-serialized"]) is CumlArray - _assert_equal(inp, ary2) assert ary._array_interface["shape"] == ary2._array_interface["shape"]