Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/paddle/jit/dy2static/transformers/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def transfer_from_node_type(self, node):
self.visit(node)

transformers = [
TypeHintTransformer, # remove all typehint
RegisterHookTransformer,
EarlyReturnTransformer,
AttributeJstTransformer, # Tensor.size -> Tensor.size(), it's unnecessary in PIR mode
Expand All @@ -107,7 +108,6 @@ def transfer_from_node_type(self, node):
CastTransformer, # type casting statement
DecoratorTransformer, # transform decorators to function call
NameloadJstTransformer,
TypeHintTransformer, # remove all typehint in gast.Name
]

apply_optimization(transformers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


from paddle.utils import gast

from .base import BaseTransformer

__all__ = []
Expand All @@ -39,3 +41,9 @@ def visit_Name(self, node):
node.annotation = None
self.generic_visit(node)
return node

def visit_AnnAssign(self, node):
if node.value is None:
return None
assign_node = gast.Assign(targets=[node.target], value=node.value)
return assign_node
50 changes: 40 additions & 10 deletions test/dygraph_to_static/test_typehint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import unittest
from typing import List

import numpy as np
from dygraph_to_static_utils import (
Expand All @@ -22,9 +23,6 @@

import paddle

SEED = 2020
np.random.seed(SEED)


class A:
pass
Expand All @@ -35,13 +33,25 @@ def function(x: A) -> A:
return 2 * x


class TestTypeHint(Dy2StTestBase):
def fn_annotation_assign_with_value(x: paddle.Tensor):
if x:
y: List["paddle.Tensor"] = [x + 1]
else:
y: List["paddle.Tensor"] = [x - 1]
return y


def fn_annotation_assign_without_value(x: paddle.Tensor):
if x:
y: List["paddle.Tensor"]
y = [x + 1]
else:
y = [x - 1]
return y


class TestTypeHints(Dy2StTestBase):
def setUp(self):
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
self.x = np.zeros(shape=(1), dtype=np.int32)
self._init_dyfunc()

Expand Down Expand Up @@ -70,9 +80,29 @@ def _run(self, to_static):
def test_ast_to_func(self):
static_numpy = self._run_static()
dygraph_numpy = self._run_dygraph()
print(static_numpy, dygraph_numpy)
np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05)


class TestAnnAssign(Dy2StTestBase):
def assert_fn_dygraph_and_static_unified(self, dygraph_fn, x):
static_fn = paddle.jit.to_static(dygraph_fn)
dygraph_fn = dygraph_fn
static_res = static_fn(x)
dygraph_res = dygraph_fn(x)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)

@test_legacy_and_pt_and_pir
def test_ann_assign_with_value(self):
self.assert_fn_dygraph_and_static_unified(
fn_annotation_assign_with_value, paddle.to_tensor(1)
)

@test_legacy_and_pt_and_pir
def test_ann_assign_without_value(self):
self.assert_fn_dygraph_and_static_unified(
fn_annotation_assign_without_value, paddle.to_tensor(1)
)


if __name__ == '__main__':
unittest.main()