@@ -15,12 +15,14 @@ limitations under the License. */
1515#include < random>
1616
1717#include " paddle/fluid/framework/generator.h"
18+ #include " paddle/fluid/framework/infershape_utils.h"
1819#include " paddle/fluid/framework/op_registry.h"
1920#include " paddle/fluid/framework/op_version_registry.h"
2021#include " paddle/fluid/operators/fill_constant_op.h"
2122#ifdef PADDLE_WITH_MKLDNN
2223#include " paddle/fluid/platform/mkldnn_helper.h"
2324#endif
25+ #include " paddle/phi/infermeta/nullary.h"
2426
2527namespace paddle {
2628namespace operators {
@@ -54,38 +56,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
5456 public:
5557 using framework::OperatorWithKernel::OperatorWithKernel;
5658
57- void InferShape (framework::InferShapeContext* ctx) const override {
58- OP_INOUT_CHECK (ctx->HasOutput (" Out" ), " Output" , " Out" , " GaussianRandom" );
59-
60- auto shape = ctx->Attrs ().Get <std::vector<int64_t >>(" shape" );
61- std::vector<int64_t > temp;
62- temp.reserve (shape.size ());
63- for (auto dim : shape) {
64- temp.push_back (static_cast <int64_t >(dim));
65- }
66- if (shape.empty () && ctx->HasInput (" ShapeTensor" )) {
67- auto shape_dims = ctx->GetInputDim (" ShapeTensor" );
68- int num_ele = 1 ;
69- for (int i = 0 ; i < shape_dims.size (); ++i) {
70- num_ele *= shape_dims[i];
71- }
72- auto vec_dims = std::vector<int >(num_ele, -1 );
73- ctx->SetOutputDim (" Out" , phi::make_ddim (vec_dims));
74-
75- return ;
76- }
77- if (!ctx->HasInput (" ShapeTensor" ) && !ctx->HasInputs (" ShapeTensorList" )) {
78- PADDLE_ENFORCE_GT (
79- shape.size (), 0UL ,
80- platform::errors::InvalidArgument (
81- " Attribute(shape) of GaussianRandomOp must be set "
82- " and shape.size() > 0, but reveived shape.size() is %d" ,
83- shape.size ()));
84- }
85-
86- ctx->SetOutputDim (" Out" , phi::make_ddim (temp));
87- }
88-
8959 protected:
9060 framework::OpKernelType GetExpectedKernelType (
9161 const framework::ExecutionContext& ctx) const override {
@@ -171,11 +141,20 @@ Used to initialize tensors with gaussian random generator.
171141} // namespace paddle
172142
173143namespace ops = paddle::operators;
174- REGISTER_OP_WITHOUT_GRADIENT (gaussian_random, ops::GaussianRandomOp,
175- ops::GaussianRandomOpMaker);
144+
145+ DECLARE_INFER_SHAPE_FUNCTOR (gaussian_random, GaussianRandomInferShapeFunctor,
146+ PD_INFER_META (phi::GaussianRandomInferMeta));
147+
148+ REGISTER_OPERATOR (
149+ gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker,
150+ paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
151+ paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
152+ GaussianRandomInferShapeFunctor);
153+
176154REGISTER_OP_CPU_KERNEL (gaussian_random_batch_size_like,
177155 ops::CPUGaussianRandomBatchSizeLikeKernel<float >,
178156 ops::CPUGaussianRandomBatchSizeLikeKernel<double >);
157+
179158REGISTER_OP_VERSION (gaussian_random)
180159 .AddCheckpoint(
181160 R"ROC(
0 commit comments