@@ -425,7 +425,7 @@ void HandleForSpecialOp(pir::Operation* op,
if (place.GetType() == phi::AllocationType::UNDEFINED) {
place = phi::CPUPlace();
}
if (phi::product(dim) >= 0) {
if (!common::contain_unknown_dim(dim)) {
phi::DenseTensorMeta meta(dtype.data(), dim);
t->set_meta(meta);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
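A hedged aside on the hunk above: phi::product(dim) >= 0 only rejects shapes whose unknown dimensions drive the product negative, so a dim vector with an even number of -1 entries (or a 0 alongside a -1) passes the old check even though the shape is not fully known. The Python sketch below is my own illustration of that corner case, not Paddle code; product and contain_unknown_dim are stand-ins for the C++ helpers, on the assumption that common::contain_unknown_dim simply tests for -1 entries.

def product(dims):
    # Multiply all entries, mirroring what phi::product is assumed to do.
    result = 1
    for d in dims:
        result *= d
    return result

def contain_unknown_dim(dims):
    # Paddle marks an unknown (dynamic) dimension as -1.
    return any(d == -1 for d in dims)

shape = [-1, -1, 8]                      # two unknown dims, product is +8
print(product(shape) >= 0)               # True: the old guard would accept this shape
print(not contain_unknown_dim(shape))    # False: the new guard correctly rejects it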
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -1983,7 +1983,7 @@ struct SelectInputOpTranscriber : public OpTranscriber {
undefine_value.defining_op()->set_attribute(
"dtype",
dialect::DataTypeAttribute::get(
ctx, PirTypeToPhiDType(undefined_var_type.dtype())));
ctx, dialect::TransToPhiDataType(undefined_var_type.dtype())));
auto& attribute_translator = AttributeTranslator::instance();
undefine_value.defining_op()->set_attribute(
"shape",
31 changes: 0 additions & 31 deletions paddle/fluid/ir_adaptor/translator/utils.cc
@@ -105,36 +105,5 @@ std::vector<std::string> CheckUnregisteredOperation(
return unregistered_ops;
}

phi::DataType PirTypeToPhiDType(pir::Type type) {
if (type.isa<pir::UInt8Type>()) {
return phi::DataType::UINT8;
} else if (type.isa<pir::Int8Type>()) {
return phi::DataType::INT8;
} else if (type.isa<pir::Int16Type>()) {
return phi::DataType::INT16;
} else if (type.isa<pir::Int32Type>()) {
return phi::DataType::INT32;
} else if (type.isa<pir::Int64Type>()) {
return phi::DataType::INT64;
} else if (type.isa<pir::Float32Type>()) {
return phi::DataType::FLOAT32;
} else if (type.isa<pir::Float64Type>()) {
return phi::DataType::FLOAT64;
} else if (type.isa<pir::BoolType>()) {
return phi::DataType::BOOL;
} else if (type.isa<pir::Float16Type>()) {
return phi::DataType::FLOAT16;
} else if (type.isa<pir::BFloat16Type>()) {
return phi::DataType::BFLOAT16;
} else if (type.isa<pir::Complex64Type>()) {
return phi::DataType::COMPLEX64;
} else if (type.isa<pir::Complex128Type>()) {
return phi::DataType::COMPLEX128;
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported pirType `%s` when casting it into phi::DataType.", type));
}
}

} // namespace translator
} // namespace paddle
2 changes: 0 additions & 2 deletions paddle/fluid/ir_adaptor/translator/utils.h
@@ -100,7 +100,5 @@ inline DataType VarTypeToDataType(
}
}

phi::DataType PirTypeToPhiDType(pir::Type type);

} // namespace translator
} // namespace paddle
3 changes: 2 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -371,7 +371,8 @@
kernel :
func : dropout
data_type : x
optional : seed_tensor, mask
optional : seed_tensor
intermediate : mask
backward : dropout_grad

- op : einsum
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -334,6 +334,7 @@
func : dropout
data_type : x
optional : seed_tensor
intermediate : mask
backward : dropout_grad

- op : einsum
2 changes: 1 addition & 1 deletion python/paddle/nn/functional/common.py
@@ -1137,7 +1137,7 @@ def dropout(
if default_main_program().random_seed != 0:
seed = default_main_program().random_seed

out, mask = _C_ops.dropout(
out = _C_ops.dropout(
x,
None,
p,
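The common.py hunk above is the Python-side counterpart of the two YAML changes: with mask declared as an intermediate output, the generated _C_ops.dropout surfaces only out, so the functional wrapper unpacks a single value. A minimal usage sketch of the public API, assuming a standard Paddle install (my own example, not code from this PR):

import paddle

x = paddle.uniform([4, 8], dtype="float32")
# Only the dropped-out tensor is returned; the mask stays internal to the op.
out = paddle.nn.functional.dropout(x, p=0.5, training=True, mode="upscale_in_train")
print(out.shape)  # [4, 8]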
96 changes: 46 additions & 50 deletions test/legacy_test/test_dropout_op.py
@@ -20,7 +20,7 @@
from utils import static_guard

import paddle
from paddle import _C_ops, base, static
from paddle import base, static
from paddle.autograd.ir_backward import grad
from paddle.base import Program, Scope, core, program_guard
from paddle.base.executor import scope_guard
@@ -79,6 +79,9 @@ def setUp(self):
'Out': self.inputs['X'],
'Mask': np.ones((32, 64)).astype('uint8'),
}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.
# Because prim op compare res with dygraph
# when p = 0 dropout api return x,in dygraph mode x_grad = out_grad,
# but in static mode x_grad = []
@@ -108,6 +111,9 @@ def setUp(self):
# when p = 0 dropout api return x,in dygraph mode x_grad = out_grad,
# but in static mode x_grad = []
self.enable_check_static_comp = False
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.


class TestDropoutOpInput1d(OpTest):
@@ -122,6 +128,9 @@ def setUp(self):
'Out': self.inputs['X'],
'Mask': np.ones(2000).astype('uint8'),
}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.
# Because prim op compare res with dygraph
# when p = 0 dropout api return x,in dygraph mode x_grad = out_grad,
# but in static mode x_grad = []
@@ -147,6 +156,9 @@ def setUp(self):
'Out': np.zeros((32, 64)).astype('float32'),
'Mask': np.zeros((32, 64)).astype('uint8'),
}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.


class TestDropoutOp2_ZeroDim(TestDropoutOp2):
@@ -161,6 +173,9 @@ def setUp(self):
'Out': np.zeros(()).astype('float32'),
'Mask': np.zeros(()).astype('uint8'),
}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.


class TestDropoutOp3(TestDropoutOp):
@@ -179,6 +194,9 @@ def setUp(self):
# when p = 0 dropout api return x,in dygraph mode x_grad = out_grad,
# but in static mode x_grad = []
self.enable_check_static_comp = False
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.


@skip_check_grad_ci(reason="For inference, check_grad is not required.")
@@ -193,6 +211,9 @@ def setUp(self):
self.outputs = {
'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob'])
}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.

def test_check_output(self):
self.check_output(check_prim=True, check_prim_pir=True, check_pir=True)
@@ -210,6 +231,9 @@ def setUp(self):
self.outputs = {
'Out': self.inputs['X'] * (1.0 - self.attrs['dropout_prob'])
}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.

def test_check_output(self):
self.check_output(check_prim=True, check_prim_pir=True, check_pir=True)
@@ -232,6 +256,9 @@ def setUp(self):
'Out': np.zeros((32, 64)).astype('float32'),
'Mask': np.zeros((32, 64)).astype('uint8'),
}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.


class TestDropoutOp7(TestDropoutOp):
@@ -255,6 +282,9 @@ def setUp(self):
# when p = 0 dropout api return x,in dygraph mode x_grad = out_grad,
# but in static mode x_grad = []
self.enable_check_static_comp = False
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.


@skip_check_grad_ci(reason="For inference, check_grad is not required.")
@@ -272,6 +302,9 @@ def setUp(self):
'dropout_implementation': 'upscale_in_train',
}
self.outputs = {'Out': self.inputs['X']}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.

def test_check_output(self):
self.check_output(check_prim=True, check_prim_pir=True, check_pir=True)
@@ -291,6 +324,9 @@ def setUp(self):
'dropout_implementation': 'upscale_in_train',
}
self.outputs = {'Out': self.inputs['X']}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.

def test_check_output(self):
self.check_output(check_prim=True, check_prim_pir=True, check_pir=True)
@@ -313,6 +349,9 @@ def setUp(self):
'Out': self.inputs['X'],
'Mask': np.ones((32, 64)).astype('uint8'),
}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.
# Because prim op compare res with dygraph
# when p = 0 dropout api return x,in dygraph mode x_grad = out_grad,
# but in static mode x_grad = []
@@ -355,6 +394,9 @@ def setUp(self):
'is_test': True,
}
self.outputs = {'Out': out}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.

def init_test_case(self):
self.input_size = [32, 64]
@@ -404,6 +446,9 @@ def setUp(self):
),
'Mask': np.zeros((32, 64)).astype('uint8'),
}
self.python_out_sig = [
"Out"
] # python out sig is customized output signature.

def test_check_output(self):
self.check_output(check_prim=True, check_prim_pir=True, check_pir=True)
@@ -1328,55 +1373,6 @@ def cal_grad_upscale_train(self, mask, prob):
def cal_grad_downscale_in_infer(self, mask):
return mask.astype("float32")

def test_backward_downscale_in_infer(self):
for place in self.places:
with base.dygraph.guard(place):
input = paddle.uniform([40, 40], dtype="float32")
input.stop_gradient = False
out, mask = _C_ops.dropout(
input, None, 0.5, False, "downgrade_in_infer", 0, False
)
out.backward()

np.testing.assert_array_equal(
input.gradient(),
self.cal_grad_downscale_in_infer(mask.numpy()),
)

def test_backward_upscale_train(self):
for place in self.places:
with base.dygraph.guard(place):
prob = 0.5
input = paddle.uniform([40, 40], dtype="float32")
input.stop_gradient = False
out, mask = _C_ops.dropout(
input, None, 0.5, False, "upscale_in_train", 0, False
)
out.backward()

np.testing.assert_allclose(
input.gradient(),
self.cal_grad_upscale_train(mask.numpy(), prob),
rtol=1e-05,
)

def test_backward_upscale_train_2(self):
for place in self.places:
with base.dygraph.guard(place):
prob = 0.3
input = paddle.uniform([40, 40], dtype="float32")
input.stop_gradient = False
out, mask = _C_ops.dropout(
input, None, 0.3, False, "upscale_in_train", 0, False
)
out.backward()

np.testing.assert_allclose(
input.gradient(),
self.cal_grad_upscale_train(mask.numpy(), prob),
rtol=1e-05,
)


class TestDropOutWithProbTensor(unittest.TestCase):
def setUp(self):
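The three backward tests removed in the hunk above called _C_ops.dropout directly to get the mask back. A hedged sketch of how the same upscale_in_train gradient check could be written against the public API instead (my own illustration, not code from this PR; it assumes the kept entries of x are nonzero so the mask can be recovered from out, and uses the fact that out = x * mask / (1 - p) implies x_grad = mask / (1 - p) after backpropagating out.sum()):

import numpy as np
import paddle

p = 0.5
x = paddle.uniform([40, 40], dtype="float32")
x.stop_gradient = False

out = paddle.nn.functional.dropout(x, p=p, training=True, mode="upscale_in_train")
out.sum().backward()

# Recover the keep-mask from the output itself: dropped entries are exactly zero.
mask = (out.numpy() != 0).astype("float32")
np.testing.assert_allclose(x.gradient(), mask / (1 - p), rtol=1e-05)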