Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/paddle/jit/dy2static/transformers/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def transfer_from_node_type(self, node):
self.visit(node)

transformers = [
TypeHintTransformer, # remove all typehint
RegisterHookTransformer,
EarlyReturnTransformer,
AttributeJstTransformer, # Tensor.size -> Tensor.size(), it's unnecessary in PIR mode
Expand All @@ -107,7 +108,6 @@ def transfer_from_node_type(self, node):
CastTransformer, # type casting statement
DecoratorTransformer, # transform decorators to function call
NameloadJstTransformer,
TypeHintTransformer, # remove all typehint in gast.Name
]

apply_optimization(transformers)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.


from paddle.utils import gast

from .base import BaseTransformer

__all__ = []
Expand All @@ -39,3 +41,9 @@ def visit_Name(self, node):
node.annotation = None
self.generic_visit(node)
return node

def visit_AnnAssign(self, node):
if node.value is None:
return None
assign_node = gast.Assign(targets=[node.target], value=node.value)
return assign_node
50 changes: 40 additions & 10 deletions test/dygraph_to_static/test_typehint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import unittest
from typing import List

import numpy as np
from dygraph_to_static_utils import (
Expand All @@ -22,9 +23,6 @@

import paddle

SEED = 2020
np.random.seed(SEED)


class A:
pass
Expand All @@ -35,13 +33,25 @@ def function(x: A) -> A:
return 2 * x


class TestTypeHint(Dy2StTestBase):
def fn_annotation_assign_with_value(x: paddle.Tensor):
if x:
y: List["paddle.Tensor"] = [x + 1]
else:
y: List["paddle.Tensor"] = [x - 1]
return y


def fn_annotation_assign_without_value(x: paddle.Tensor):
if x:
y: List["paddle.Tensor"]
y = [x + 1]
else:
y = [x - 1]
return y


class TestTypeHints(Dy2StTestBase):
def setUp(self):
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
self.x = np.zeros(shape=(1), dtype=np.int32)
self._init_dyfunc()

Expand Down Expand Up @@ -70,9 +80,29 @@ def _run(self, to_static):
def test_ast_to_func(self):
static_numpy = self._run_static()
dygraph_numpy = self._run_dygraph()
print(static_numpy, dygraph_numpy)
np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05)


class TestAnnAssign(Dy2StTestBase):
def assert_fn_dygraph_and_static_unified(self, dygraph_fn, x):
static_fn = paddle.jit.to_static(fn_annotation_assign_with_value)
dygraph_fn = dygraph_fn
static_res = static_fn(x)
dygraph_res = dygraph_fn(x)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)

@test_legacy_and_pt_and_pir
def test_ann_assign_with_value(self):
self.assert_fn_dygraph_and_static_unified(
fn_annotation_assign_with_value, paddle.to_tensor(1)
)

@test_legacy_and_pt_and_pir
def test_ann_assign_without_value(self):
self.assert_fn_dygraph_and_static_unified(
fn_annotation_assign_without_value, paddle.to_tensor(1)
)


if __name__ == '__main__':
unittest.main()