5 changes: 3 additions & 2 deletions test/deprecated/prim/prim/flags/CMakeLists.txt
@@ -9,6 +9,7 @@ foreach(TEST_OP ${TEST_OPS})
 endforeach()
 
 if(WITH_CINN)
-  set_tests_properties(test_prim_flags_case PROPERTIES LABELS "RUN_TYPE=CINN")
-  set_tests_properties(test_prim_flags_case PROPERTIES TIMEOUT 300)
+  set_tests_properties(test_prim_flags_case_deprecated
+                       PROPERTIES LABELS "RUN_TYPE=CINN")
+  set_tests_properties(test_prim_flags_case_deprecated PROPERTIES TIMEOUT 300)
 endif()
6 changes: 0 additions & 6 deletions test/deprecated/prim/prim/vjp/static/CMakeLists.txt
@@ -8,9 +8,3 @@ set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0)
 foreach(TEST_OP ${TEST_OPS})
   py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS})
 endforeach()
-
-set_tests_properties(test_comp_div_grad PROPERTIES TIMEOUT 60)
-set_tests_properties(test_comp_add_grad PROPERTIES TIMEOUT 60)
-set_tests_properties(test_comp_sub_grad PROPERTIES TIMEOUT 60)
-set_tests_properties(test_comp_add_tanh_grad PROPERTIES TIMEOUT 60)
-set_tests_properties(test_comp_sqrt_grad PROPERTIES TIMEOUT 60)
5 changes: 5 additions & 0 deletions test/prim/prim/vjp/static/CMakeLists.txt
@@ -11,3 +11,8 @@ endforeach()
 
 set_tests_properties(test_comp_sum_grad PROPERTIES TIMEOUT 60)
 set_tests_properties(test_comp_tanh_grad PROPERTIES TIMEOUT 60)
+set_tests_properties(test_comp_div_grad PROPERTIES TIMEOUT 60)
+set_tests_properties(test_comp_add_grad PROPERTIES TIMEOUT 60)
+set_tests_properties(test_comp_sub_grad PROPERTIES TIMEOUT 60)
+set_tests_properties(test_comp_add_tanh_grad PROPERTIES TIMEOUT 60)
+set_tests_properties(test_comp_sqrt_grad PROPERTIES TIMEOUT 60)
(changed file; path not shown)
@@ -90,21 +90,9 @@ def train(self, use_prim, use_cinn):
 
         return res
 
-    def test_cinn(self):
-        paddle.disable_static()
-        dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
-
-        for i in range(len(dy_res)):
-            np.testing.assert_allclose(
-                comp_st_cinn_res[i].numpy(),
-                dy_res[i].numpy(),
-                rtol=1e-7,
-                atol=1e-7,
-            )
+    def test_tanh_grad_comp(self):
+        paddle.enable_static()
 
-    def test_tanh_grad_comp(self):
         def actual(primal0, primal1):
             core._set_prim_backward_enabled(True)
             mp, sp = paddle.static.Program(), paddle.static.Program()
@@ -123,7 +111,7 @@ def actual(primal0, primal1):
                     'primal0': primal0,
                     'primal1': primal1,
                 },
-                fetch_list=[res[0].name, res[1].name],
+                fetch_list=[res[0], res[1]],
             )
             return out[0], out[1]
 
@@ -149,7 +137,7 @@ def desired(primal0, primal1):
                     'primal0': self.primal0,
                     'primal1': self.primal1,
                 },
-                fetch_list=[res[0].name, res[1].name],
+                fetch_list=[res[0], res[1]],
             )
             return out[0], out[1]
 
@@ -170,6 +158,7 @@ def desired(primal0, primal1):
             atol=0,
         )
         core._set_prim_backward_enabled(False)
+        paddle.disable_static()
 
 
 if __name__ == '__main__':
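Note on the recurring edit above: each test file swaps fetch_list=[res[0].name, ...] for fetch_list=[res[0], ...]. Executor.run accepts Variable objects in fetch_list as well as name strings, and fetching by object avoids relying on .name; this presumably also keeps the tests working under PIR, where fetch targets are identified by value rather than by name (an inference from the change, not stated in the diff). A minimal, self-contained sketch of the pattern (this program is illustrative, not code from the PR):

import numpy as np
import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', shape=[2, 3], dtype='float32')
    y = paddle.tanh(x)

exe = paddle.static.Executor()
exe.run(startup)
(out,) = exe.run(
    main,
    feed={'x': np.ones([2, 3], dtype='float32')},
    fetch_list=[y],  # the Variable itself, rather than the string y.name
)
print(out.shape)  # (2, 3)
paddle.disable_static()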
(changed file; path not shown)
@@ -91,20 +91,6 @@ def train(self, use_prim, use_cinn):
 
         return res
 
-    def test_cinn(self):
-        paddle.disable_static()
-        dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
-
-        for i in range(len(dy_res)):
-            np.testing.assert_allclose(
-                comp_st_cinn_res[i].numpy(),
-                dy_res[i].numpy(),
-                rtol=1e-7,
-                atol=1e-7,
-            )
-        paddle.enable_static()
-
     def test_tanh_grad_comp(self):
         paddle.enable_static()
 
@@ -127,7 +113,7 @@ def actual(primal0, primal1):
                     'primal0': primal0,
                     'primal1': primal1,
                 },
-                fetch_list=[res[0].name, res[1].name],
+                fetch_list=[res[0], res[1]],
             )
             return out[0], out[1]
 
@@ -154,7 +140,7 @@ def desired(primal0, primal1):
                     'primal0': self.primal0,
                     'primal1': self.primal1,
                 },
-                fetch_list=[res[0].name, res[1].name],
+                fetch_list=[res[0], res[1]],
             )
             return out[0], out[1]
 
(changed file; path not shown)
@@ -90,21 +90,9 @@ def train(self, use_prim, use_cinn):
 
         return res
 
-    def test_cinn(self):
-        paddle.disable_static()
-        dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
-
-        for i in range(len(dy_res)):
-            np.testing.assert_allclose(
-                comp_st_cinn_res[i].numpy(),
-                dy_res[i].numpy(),
-                rtol=1e-6,
-                atol=1e-6,
-            )
+    def test_tanh_grad_comp(self):
+        paddle.enable_static()
 
-    def test_tanh_grad_comp(self):
         def actual(primal0, primal1):
             core._set_prim_backward_enabled(True)
             mp, sp = paddle.static.Program(), paddle.static.Program()
@@ -123,7 +111,7 @@ def actual(primal0, primal1):
                     'primal0': primal0,
                     'primal1': primal1,
                 },
-                fetch_list=[res[0].name, res[1].name],
+                fetch_list=[res[0], res[1]],
             )
             return out[0], out[1]
 
@@ -149,7 +137,7 @@ def desired(primal0, primal1):
                     'primal0': self.primal0,
                     'primal1': self.primal1,
                 },
-                fetch_list=[res[0].name, res[1].name],
+                fetch_list=[res[0], res[1]],
             )
             return out[0], out[1]
 
@@ -170,6 +158,7 @@ def desired(primal0, primal1):
             atol=0,
         )
         core._set_prim_backward_enabled(False)
+        paddle.disable_static()
 
 
 if __name__ == '__main__':
(changed file; path not shown)
@@ -101,6 +101,8 @@ def actual(primal, cotangent):
                 )
                 y = paddle.exp(x)
                 x_cotangent = paddle.static.gradients(y, x, v)
+                if x_cotangent == [None]:
+                    x_cotangent = []
             exe = paddle.static.Executor()
             exe.run(sp)
             return exe.run(
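The two added lines above handle paddle.static.gradients returning [None] when there is no gradient path back to the input; normalizing to an empty list keeps the later fetch from failing on a None entry. An illustrative sketch of the case being guarded (assuming an input with stop_gradient=True; not code from the PR):

import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data('x', shape=[3], dtype='float32')
    x.stop_gradient = True  # no gradient flows back to x
    y = paddle.exp(x)
    grads = paddle.static.gradients(y, x)

print(grads)  # expected: [None]; the guard above rewrites this to []
paddle.disable_static()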
(changed file; path not shown)
@@ -238,9 +238,9 @@ def actual(primal0, primal1, primal2, trans_0, trans_1):
                     'primal2': primal2,
                 },
                 fetch_list=[
-                    res_double[0].name,
-                    res_double[1].name,
-                    res_double[2].name,
+                    res_double[0],
+                    res_double[1],
+                    res_double[2],
                 ],
             )
 
@@ -271,9 +271,9 @@ def desired(primal0, primal1, primal2, trans_0, trans_1):
                     'primal2': primal2,
                 },
                 fetch_list=[
-                    res_double[0].name,
-                    res_double[1].name,
-                    res_double[2].name,
+                    res_double[0],
+                    res_double[1],
+                    res_double[2],
                 ],
             )
 
(changed file; path not shown)
@@ -66,7 +66,7 @@ def actual(primal, cotangent):
                     'cotangent': cotangent,
                 },
                 fetch_list=[
-                    x_grad[0].name,
+                    x_grad[0],
                 ],
             )
 
@@ -95,7 +95,7 @@ def desired(primal, cotangent):
                     'cotangent': cotangent,
                 },
                 fetch_list=[
-                    x_grad[0].name,
+                    x_grad[0],
                 ],
             )
 
(changed file; path not shown)
@@ -69,21 +69,9 @@ def train(self, use_prim, use_cinn):
 
         return res
 
-    def test_cinn(self):
-        paddle.disable_static()
-        dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
-
-        for i in range(len(dy_res)):
-            np.testing.assert_allclose(
-                comp_st_cinn_res[i].numpy(),
-                dy_res[i].numpy(),
-                rtol=1e-7,
-                atol=1e-7,
-            )
+    def test_sqrt_grad_comp(self):
+        paddle.enable_static()
 
-    def test_sqrt_grad_comp(self):
         def actual(primal, cotangent):
             mp, sp = paddle.static.Program(), paddle.static.Program()
             with paddle.static.program_guard(mp, sp):
@@ -99,7 +87,7 @@ def actual(primal, cotangent):
             return exe.run(
                 program=mp,
                 feed={'primal': primal, 'cotangent': cotangent},
-                fetch_list=[x_cotangent[0].name],
+                fetch_list=[x_cotangent[0]],
             )[0]
 
         def desired(primal, cotangent):
@@ -112,6 +100,7 @@ def desired(primal, cotangent):
             atol=0,
         )
         core._set_prim_backward_enabled(False)
+        paddle.disable_static()
 
 
 if __name__ == '__main__':
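The same reshaping recurs across these files: the test_cinn method is removed, the surviving test begins with paddle.enable_static(), and a closing paddle.disable_static() is added so the process-wide mode is restored for whatever test runs next. A hedged sketch of that bracketing (names, shapes, and the try/finally are illustrative; the PR itself calls paddle.disable_static() unconditionally at the end of the test):

import unittest

import numpy as np
import paddle


class StaticGradBracketSketch(unittest.TestCase):
    def test_sqrt_grad(self):
        paddle.enable_static()  # enter static mode for this test only
        try:
            main, startup = paddle.static.Program(), paddle.static.Program()
            with paddle.static.program_guard(main, startup):
                x = paddle.static.data('x', shape=[3], dtype='float32')
                x.stop_gradient = False
                y = paddle.sqrt(x)
                (dx,) = paddle.static.gradients(y, x)
            exe = paddle.static.Executor()
            exe.run(startup)
            (out,) = exe.run(
                main,
                feed={'x': np.full([3], 4.0, dtype='float32')},
                fetch_list=[dx],  # fetch by Variable, matching the PR's style
            )
            # d/dx sqrt(x) = 1 / (2 * sqrt(x)), which is 0.25 at x = 4
            np.testing.assert_allclose(out, np.full([3], 0.25), rtol=1e-6)
        finally:
            paddle.disable_static()  # restore dynamic mode either way


if __name__ == '__main__':
    unittest.main()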
(changed file; path not shown)
@@ -91,21 +91,9 @@ def train(self, use_prim, use_cinn):
 
         return res
 
-    def test_cinn(self):
-        paddle.disable_static()
-        dy_res = self.train(use_prim=False, use_cinn=False)
-        comp_st_cinn_res = self.train(use_prim=True, use_cinn=True)
-
-        for i in range(len(dy_res)):
-            np.testing.assert_allclose(
-                comp_st_cinn_res[i].numpy(),
-                dy_res[i].numpy(),
-                rtol=1e-7,
-                atol=1e-7,
-            )
+    def test_tanh_grad_comp(self):
+        paddle.enable_static()
 
-    def test_tanh_grad_comp(self):
         def actual(primal0, primal1):
             core._set_prim_backward_enabled(True)
             mp, sp = paddle.static.Program(), paddle.static.Program()
@@ -124,7 +112,7 @@ def actual(primal0, primal1):
                     'primal0': primal0,
                     'primal1': primal1,
                 },
-                fetch_list=[res[0].name, res[1].name],
+                fetch_list=[res[0], res[1]],
             )
             return out[0], out[1]
 
@@ -150,7 +138,7 @@ def desired(primal0, primal1):
                     'primal0': self.primal0,
                     'primal1': self.primal1,
                 },
-                fetch_list=[res[0].name, res[1].name],
+                fetch_list=[res[0], res[1]],
             )
             return out[0], out[1]
 
@@ -171,6 +159,7 @@ def desired(primal0, primal1):
             atol=0,
         )
         core._set_prim_backward_enabled(False)
+        paddle.disable_static()
 
 
 if __name__ == '__main__':
(changed file; path not shown)
@@ -182,7 +182,7 @@ def actual(primal, axis, cotangent):
             return exe.run(
                 program=mp,
                 feed={'primal': primal, 'cotangent': cotangent},
-                fetch_list=[x_cotangent[0].name],
+                fetch_list=[x_cotangent[0]],
             )[0]
 
         def desired(primal, axis, cotangent):
@@ -207,7 +207,7 @@ def desired(primal, axis, cotangent):
             return exe.run(
                 program=mp,
                 feed={'primal': primal, 'cotangent': cotangent},
-                fetch_list=[x_cotangent[0].name],
+                fetch_list=[x_cotangent[0]],
             )[0]
 
         if (self.dtype == np.float16) and isinstance(