8 changes: 0 additions & 8 deletions python/paddle/jit/dy2static/origin_info.py
@@ -133,14 +133,6 @@ def _attach_origin_info(self, node):
setattr(node, ORIGI_INFO, origin_info)

def _abs_lineno(self, node):
# NOTE(liym27):
# There are differences in ast_node.lineno between PY3.8+ and PY3.8-.
# If the first gast.FunctionDef has a decorator, the lineno of the gast.FunctionDef differs:
# 1. < PY3.8
# its lineno equals the lineno of the first decorator node, which is not right.
# 2. >= PY3.8
# its lineno is the actual lineno, which is right.

return self.lineno_offset + node.lineno

def _abs_col_offset(self, node):
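The deleted note is obsolete presumably because Paddle now requires Python 3.8+, where ast/gast report a decorated function's lineno correctly. A quick standalone check of the behavior the note described, using the standard ast module (gast mirrors it here):

import ast

src = "@deco\ndef f():\n    pass\n"
fn = ast.parse(src).body[0]
# On Python >= 3.8 the FunctionDef's lineno is the `def` line itself,
assert fn.lineno == 2
# while the decorator keeps its own lineno.
assert fn.decorator_list[0].lineno == 1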
20 changes: 6 additions & 14 deletions python/paddle/optimizer/optimizer.py
@@ -32,7 +32,6 @@
in_dynamic_or_pir_mode,
in_pir_mode,
name_scope,
use_pir_api,
)
from paddle.regularizer import L2Decay

@@ -795,19 +794,12 @@ def _create_param_lr(self, param_and_grad):
if param_lr == 1.0:
return self._global_learning_rate()
else:
if not use_pir_api():
with paddle.static.default_main_program()._lr_schedule_guard(
is_with_opt=True
), framework.name_scope(
'scale_with_param_lr'
):
return self._global_learning_rate() * param_lr
else:
# TODO(dev): Currently there is no equivalent of op_role in PIR
# mode, so we simply remove _lr_schedule_guard here; this should
# be fixed in the future.
with framework.name_scope('scale_with_param_lr'):
return self._global_learning_rate() * param_lr
with paddle.static.default_main_program()._lr_schedule_guard(
is_with_opt=True
), framework.name_scope(
'scale_with_param_lr'
):
return self._global_learning_rate() * param_lr
else:
return self._global_learning_rate()

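With the use_pir_api() branch gone, both IR modes take the same path. A minimal standalone sketch of the resulting logic, with global_lr standing in for self._global_learning_rate() and a plain dict for the parameter's optimize_attr (the name-scope and LR-schedule guards are left out since they don't change the arithmetic):

def scale_param_lr(global_lr, optimize_attr):
    # Per-parameter LR multiplier; 1.0 means "use the global LR as-is".
    param_lr = optimize_attr.get('learning_rate', 1.0)
    if param_lr == 1.0:
        return global_lr
    # Otherwise scale the global learning rate, as the unified branch does.
    return global_lr * param_lr

assert scale_param_lr(0.1, {'learning_rate': 2.0}) == 0.2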
21 changes: 18 additions & 3 deletions python/paddle/pir/program_patch.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.base.wrapped_decorator import signature_safe_contextmanager

from . import Program

_already_patch_program = False
@@ -25,10 +27,23 @@ def global_seed(self, seed=0):
global_prog_seed = seed
self._seed = global_prog_seed

@signature_safe_contextmanager
def _lr_schedule_guard(self, is_with_opt=False):
# TODO(dev): Currently there is no equivalent of op_role in PIR
# mode, so _lr_schedule_guard is a no-op here; this should be
# fixed in the future.
yield

global global_prog_seed
program_attrs = {
"global_seed": global_seed,
"_seed": global_prog_seed,
"_lr_schedule_guard": _lr_schedule_guard,
}

global _already_patch_program
if not _already_patch_program:
Program.global_seed = global_seed
global global_prog_seed
Program._seed = global_prog_seed
for attr, value in program_attrs.items():
setattr(Program, attr, value)

_already_patch_program = True
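The patch follows a collect-then-apply pattern: gather the new attributes in a dict, set them on Program in one loop, and guard against double patching with a module-level flag. A self-contained sketch of that pattern, substituting contextlib.contextmanager for Paddle's internal signature_safe_contextmanager and a stand-in class for paddle.pir.Program:

from contextlib import contextmanager

_already_patched = False

class FakeProgram:  # stand-in for paddle.pir.Program
    pass

@contextmanager
def _lr_schedule_guard(self, is_with_opt=False):
    # No-op until PIR has an op_role equivalent (see the TODO above).
    yield

def apply_patches():
    global _already_patched
    if _already_patched:  # only patch the class once per process
        return
    attrs = {"_lr_schedule_guard": _lr_schedule_guard}
    for name, value in attrs.items():
        setattr(FakeProgram, name, value)
    _already_patched = True

apply_patches()
with FakeProgram()._lr_schedule_guard(is_with_opt=True):
    pass  # body runs unguarded under PIR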
122 changes: 90 additions & 32 deletions test/dygraph_to_static/test_break_continue.py
@@ -15,11 +15,17 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase, test_ast_only
from dygraph_to_static_utils import (
Dy2StTestBase,
IrMode,
ToStaticMode,
disable_test_case,
enable_to_static_guard,
test_ast_only,
test_legacy_and_pt_and_pir,
)

import paddle
from paddle import base
from paddle.jit.api import to_static
from paddle.jit.dy2static.utils import Dygraph2StaticException

SEED = 2020
@@ -36,14 +42,12 @@ def setUp(self):
def test_error(self):
if self.dyfunc:
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
paddle.jit.enable_to_static(True)
self.assertTrue(to_static(self.dyfunc)(self.x))
paddle.base.dygraph.base.global_var._in_to_static_mode_ = False
paddle.jit.enable_to_static(False)
with enable_to_static_guard(True):
self.assertTrue(paddle.jit.to_static(self.dyfunc)(self.x))


def test_continue_in_for(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
for i in range(10):
x += 1
if i > 5:
@@ -54,7 +58,7 @@ def test_continue_in_for(x):


def test_continue_in_for_at_end(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
for i in range(10):
x += 1
if i > 5:
@@ -63,7 +67,7 @@ def test_continue_in_for_at_end(x):


def test_continue_in_while(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
while i < 10:
i += 1
@@ -75,7 +79,7 @@ def test_continue_in_while(x):


def test_break_in_for(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
for i in range(10):
x += 1
if i > 5:
@@ -86,7 +90,7 @@ def test_break_in_for(x):


def test_break_in_for_at_end(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
for i in range(10):
x += 1
if i > 5:
@@ -95,7 +99,7 @@ def test_break_in_for_at_end(x):


def test_break_in_while(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)
i = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=0)
while i < 10:
i += 1
@@ -107,7 +111,7 @@ def test_break_in_while(x):


def test_break_continue_in_for(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)

for i in range(1, 10, 1):
if i <= 4:
@@ -137,7 +141,7 @@ def test_break_continue_in_for(x):


def test_for_in_else(x):
x = base.dygraph.to_variable(x)
x = paddle.to_tensor(x)

# Case 1:
if False:
@@ -168,7 +172,7 @@ def __init__(self):
self.c = 5

foo = Foo()
i = base.dygraph.to_variable(x)
i = paddle.to_tensor(x)
while i < 10:
foo.b = paddle.zeros(shape=[1], dtype='float32')
foo.c = foo.b + foo.a
@@ -204,32 +208,47 @@ def test_optim_break_in_while(x):
return x


class TestContinueInFor(Dy2StTestBase):
class TestContinueBase(Dy2StTestBase):
def setUp(self):
self.input = np.zeros(1).astype('int64')
self.place = (
base.CUDAPlace(0)
if base.is_compiled_with_cuda()
else base.CPUPlace()
)
self.init_dygraph_func()

def init_dygraph_func(self):
self.dygraph_func = test_continue_in_for
raise NotImplementedError(
"Subclasses of TestContinueBase should implement init_dygraph_func"
)

def run_dygraph_mode(self):
with base.dygraph.guard():
res = self.dygraph_func(self.input)
return res.numpy()
res = self.dygraph_func(self.input)
return res.numpy()

def run_static_mode(self):
with base.dygraph.guard():
res = to_static(self.dygraph_func)(self.input)
return res.numpy()
res = paddle.jit.to_static(self.dygraph_func)(self.input)
return res.numpy()


class TestContinueInFor(TestContinueBase):
def init_dygraph_func(self):
self.dygraph_func = test_continue_in_for

@test_legacy_and_pt_and_pir
def test_transformed_static_result(self):
self.init_dygraph_func()
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
np.testing.assert_allclose(
dygraph_res,
static_res,
rtol=1e-05,
err_msg=f'dygraph res is {dygraph_res}\nstatic_res is {static_res}',
)


# TODO(pir-control-flow): Fix this after we support control-flow in PIR
class TestContinueNotPirBase(TestContinueInFor):
def test_transformed_static_result(self):
self.init_dygraph_func()
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
np.testing.assert_allclose(
dygraph_res,
static_res,
@@ -253,7 +272,7 @@ def init_dygraph_func(self):
self.dygraph_func = test_break_in_for_at_end


class TestBreakContinueInFor(TestContinueInFor):
class TestBreakContinueInFor(TestContinueNotPirBase):
def init_dygraph_func(self):
self.dygraph_func = test_break_continue_in_for

@@ -263,15 +282,41 @@ def init_dygraph_func(self):
self.dygraph_func = test_for_in_else


class TestContinueInWhile(TestContinueInFor):
class TestContinueInWhile(TestContinueNotPirBase):
def init_dygraph_func(self):
self.dygraph_func = test_continue_in_while

# TODO(dev): Remove this after the PT Rename issue is fixed
@disable_test_case((ToStaticMode.AST, IrMode.PT))
def test_transformed_static_result(self):
self.init_dygraph_func()
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
np.testing.assert_allclose(
dygraph_res,
static_res,
rtol=1e-05,
err_msg=f'dygraph res is {dygraph_res}\nstatic_res is {static_res}',
)


class TestBreakInWhile(TestContinueInWhile):
def init_dygraph_func(self):
self.dygraph_func = test_break_in_while

# TODO(dev): Remove this after the PT Rename issue is fixed
@disable_test_case((ToStaticMode.AST, IrMode.PT))
def test_transformed_static_result(self):
self.init_dygraph_func()
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
np.testing.assert_allclose(
dygraph_res,
static_res,
rtol=1e-05,
err_msg=f'dygraph res is {dygraph_res}\nstatic_res is {static_res}',
)


class TestWhileLoopClassVar(TestContinueInWhile):
def init_dygraph_func(self):
@@ -289,6 +334,19 @@ class TestOptimBreakInWhile(TestContinueInWhile):
def init_dygraph_func(self):
self.dygraph_func = test_optim_break_in_while

# TODO(dev): Remove this after the PT Rename issue is fixed
@disable_test_case((ToStaticMode.AST, IrMode.PT))
def test_transformed_static_result(self):
self.init_dygraph_func()
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
np.testing.assert_allclose(
dygraph_res,
static_res,
rtol=1e-05,
err_msg=f'dygraph res is {dygraph_res}\nstatic_res is {static_res}',
)


if __name__ == '__main__':
unittest.main()
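The refactored error test leans on enable_to_static_guard, which scopes the to-static switch to a with block instead of flipping paddle.jit.enable_to_static and an internal flag by hand. A hedged usage sketch (dygraph_to_static_utils is the test-helper module imported at the top of this file; the guard is assumed to restore the previous state on exit):

import numpy as np
from dygraph_to_static_utils import enable_to_static_guard

import paddle

def double(x):
    return paddle.to_tensor(x) * 2

with enable_to_static_guard(True):
    out = paddle.jit.to_static(double)(np.ones(1).astype('int64'))
# The previous to-static setting is restored here even if the call
# above raises, so later tests in the process are unaffected.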
16 changes: 10 additions & 6 deletions test/dygraph_to_static/test_cpu_cuda_to_tensor.py
@@ -17,13 +17,14 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pt_and_pir,
)

import paddle


class TestCpuCuda(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_cpu_cuda(self):
def func(x):
x = paddle.to_tensor([1, 2, 3, 4])
@@ -33,10 +34,13 @@ def func(x):

x = paddle.to_tensor([3])
# print(paddle.jit.to_static(func).code)
# print(paddle.jit.to_static(func)(x))
if paddle.is_compiled_with_cuda():
res = paddle.jit.to_static(func)(x)
self.assertTrue(res.place.is_cpu_place())


class TestToTensor(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_to_tensor_with_variable_list(self):
def func(x):
ones = paddle.to_tensor(1)
@@ -54,7 +58,7 @@ def func(x):


class TestToTensor1(Dy2StTestBase):
@test_ast_only
@test_legacy_and_pt_and_pir
def test_to_tensor_with_variable_list(self):
def func(x):
ones = paddle.to_tensor([1])
@@ -72,7 +76,7 @@ def func(x):
rtol=1e-05,
)

@test_ast_only
@test_legacy_and_pt_and_pir
def test_to_tensor_with_variable_list_sot(self):
def func(x):
ones = paddle.to_tensor([1])
Expand All @@ -92,7 +96,7 @@ def func(x):


class TestToTensor2(Dy2StTestBase):
@test_ast_only
@test_legacy_and_pt_and_pir
def test_to_tensor_with_variable_list(self):
def func(x):
x = paddle.to_tensor([[1], [2], [3], [4]])
@@ -105,7 +109,7 @@ def func(x):
rtol=1e-05,
)

@test_ast_only
@test_legacy_and_pt_and_pir
def test_to_tensor_with_variable_list_sot(self):
def func(x):
x = paddle.to_tensor([[1], [2], [3], [4]])
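The new CUDA-guarded assertion checks tensor placement through the public Place API. A minimal sketch of the same check (func's body is collapsed in the diff, so the .cpu() call here is only an illustrative way to land the tensor in host memory):

import paddle

if paddle.is_compiled_with_cuda():
    t = paddle.to_tensor([3]).cpu()  # move the tensor to host memory
    assert t.place.is_cpu_place()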