diff --git a/test/xpu/collective_concat_op.py b/test/xpu/collective_concat_op.py index ebda6c3883f6a8..e8a3dbe74524d0 100644 --- a/test/xpu/collective_concat_op.py +++ b/test/xpu/collective_concat_op.py @@ -16,7 +16,7 @@ import paddle from paddle import base -from paddle.base import core, layers +from paddle.base import core paddle.enable_static() @@ -25,16 +25,16 @@ class TestCollectiveConcat(TestCollectiveRunnerBase): def __init__(self): self.global_ring_id = 0 - def get_model(self, main_prog, startup_program): + def get_model(self, main_prog, startup_program, dtype='float32'): ring_id = 0 nranks = 2 with base.program_guard(main_prog, startup_program): - tindata = layers.data( - name="tindata", shape=[10, 1000], dtype='float32' + tindata = paddle.static.data( + name="tindata", shape=[10, 1000], dtype=dtype ) toutdata = main_prog.current_block().create_var( name="outofconcat", - dtype='float32', + dtype=dtype, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, stop_gradient=False, diff --git a/test/xpu/collective_split_op.py b/test/xpu/collective_split_op.py index 82f8db770e524a..966df54b8fe2b7 100644 --- a/test/xpu/collective_split_op.py +++ b/test/xpu/collective_split_op.py @@ -16,7 +16,7 @@ import paddle from paddle import base -from paddle.base import core, layers +from paddle.base import core paddle.enable_static() @@ -25,16 +25,16 @@ class TestCollectiveAllGather(TestCollectiveRunnerBase): def __init__(self): self.global_ring_id = 0 - def get_model(self, main_prog, startup_program): + def get_model(self, main_prog, startup_program, dtype='float32'): ring_id = 0 nranks = 2 with base.program_guard(main_prog, startup_program): - tindata = layers.data( - name="tindata", shape=[10, 1000], dtype='float32' + tindata = paddle.static.data( + name="tindata", shape=[10, 1000], dtype=dtype ) toutdata = main_prog.current_block().create_var( name="outofsplit", - dtype='float32', + dtype=dtype, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, stop_gradient=False, diff --git a/test/xpu/test_collective_softmax_with_cross_entropy_xpu.py b/test/xpu/test_collective_softmax_with_cross_entropy_xpu.py index 5935785fba50d3..21d333d222ef98 100644 --- a/test/xpu/test_collective_softmax_with_cross_entropy_xpu.py +++ b/test/xpu/test_collective_softmax_with_cross_entropy_xpu.py @@ -95,7 +95,7 @@ def check_with_place( self, model_file, col_type, - dtype, + dtype=None, check_error_log=False, need_envs={}, ): @@ -105,8 +105,9 @@ def check_with_place( "PYTHONPATH": os.getenv("PYTHONPATH", ""), "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), "LD_PRELOAD": os.getenv("LD_PRELOAD", ""), - "GLOG_v": "0", + "GLOG_v": "3", "DTYPE": dtype, + "FLAGS_dynamic_static_unified_comm": "0", } required_envs.update(need_envs) if check_error_log: