Skip to content

Conversation

@zhiqiu
Copy link
Contributor

@zhiqiu zhiqiu commented Sep 7, 2021

PR types

New features

PR changes

Others

Describe

change metaclass of Layer from pybind11_builtins.pybind11_type to type

By default, the mateclass of Layer is pybind11_builtins.pybind11_type, and the mateclass of pybind11_builtins.pybind11_type is type.
This PR changes mateclass of Layer to type directly to fix some problem.

For example,

import paddle
import cloudpickle
import os
import subprocess
import torch

class PaddleLayer(paddle.nn.Layer):
    def __init__(self):
        super(PaddleLayer, self).__init__()
        self.linear_1 = paddle.nn.Linear(784, 512)

class TorchModule(torch.nn.Module):
    def __init__(self):
        super(TorchModule, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5)

class OtherClass(object):
    def __init__(self):
        pass

def test_cloudpickle(cls):
    def to_str(byte):
        """ convert byte to string in pytohn2/3
        """
        return str(byte.decode())

    encoded = cloudpickle.dumps(cls)
    decoded_cls = cloudpickle.loads(encoded) # can be deserialized in the same process
    obj = decoded_cls()

    fname = "{}.cloudpickle".format(cls.__name__)
    with open(fname, 'wb') as f:
        f.write(encoded)

    command = """python -c '
import cloudpickle
with open("{}", "rb") as f:
    encoded = f.read()
    decoded_cls = cloudpickle.loads(encoded)
    obj = decoded_cls()
'
""".format(fname)
    # cannot be deserialized in another process
    try:
        subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
        print("load {} ok".format(fname))
    except subprocess.CalledProcessError as e:
        print("load {} failed".format(fname))
        print(str(e.output.decode()))

test_cloudpickle(OtherClass)
test_cloudpickle(TorchModule)
test_cloudpickle(PaddleLayer)
  • before
load OtherClass.cloudpickle ok
load TorchModule.cloudpickle ok
load PaddleLayer.cloudpickle failed
Traceback (most recent call last):
  File "<string>", line 5, in <module>
  File "/usr/local/lib/python3.7/dist-packages/cloudpickle/cloudpickle.py", line 736, in _make_skeleton_class
    lambda ns: ns.update(type_kwargs)
  File "/usr/lib/python3.7/types.py", line 65, in new_class
    meta, ns, kwds = prepare_class(name, resolved_bases, kwds)
  File "/usr/lib/python3.7/types.py", line 118, in prepare_class
    meta = _calculate_meta(meta, bases)
  File "/usr/lib/python3.7/types.py", line 136, in _calculate_meta
    raise TypeError("metaclass conflict: "
TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases
  • after
load OtherClass.cloudpickle ok
load TorchModule.cloudpickle ok
load PaddleLayer.cloudpickle ok

@paddle-bot-old
Copy link

paddle-bot-old bot commented Sep 7, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhiqiu zhiqiu merged commit 523f46f into PaddlePaddle:develop Sep 10, 2021
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 29, 2021
PaddlePaddle#35538)

* change metaclass of Layer from pybind11_builtins.pybind11_type to type

* fix cast

* add ut
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants