Skip to content

Commit 23b820c

Browse files
committed
[Phi] move InferShape for truncated_gaussian_random and gaussian_random
1 parent 0fb6bca commit 23b820c

7 files changed

Lines changed: 72 additions & 62 deletions

File tree

paddle/fluid/operators/gaussian_random_op.cc

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ limitations under the License. */
1515
#include <random>
1616

1717
#include "paddle/fluid/framework/generator.h"
18+
#include "paddle/fluid/framework/infershape_utils.h"
1819
#include "paddle/fluid/framework/op_registry.h"
1920
#include "paddle/fluid/framework/op_version_registry.h"
2021
#include "paddle/fluid/operators/fill_constant_op.h"
2122
#ifdef PADDLE_WITH_MKLDNN
2223
#include "paddle/fluid/platform/mkldnn_helper.h"
2324
#endif
25+
#include "paddle/phi/infermeta/nullary.h"
2426

2527
namespace paddle {
2628
namespace operators {
@@ -54,38 +56,6 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
5456
public:
5557
using framework::OperatorWithKernel::OperatorWithKernel;
5658

57-
void InferShape(framework::InferShapeContext* ctx) const override {
58-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GaussianRandom");
59-
60-
auto shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
61-
std::vector<int64_t> temp;
62-
temp.reserve(shape.size());
63-
for (auto dim : shape) {
64-
temp.push_back(static_cast<int64_t>(dim));
65-
}
66-
if (shape.empty() && ctx->HasInput("ShapeTensor")) {
67-
auto shape_dims = ctx->GetInputDim("ShapeTensor");
68-
int num_ele = 1;
69-
for (int i = 0; i < shape_dims.size(); ++i) {
70-
num_ele *= shape_dims[i];
71-
}
72-
auto vec_dims = std::vector<int>(num_ele, -1);
73-
ctx->SetOutputDim("Out", phi::make_ddim(vec_dims));
74-
75-
return;
76-
}
77-
if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) {
78-
PADDLE_ENFORCE_GT(
79-
shape.size(), 0UL,
80-
platform::errors::InvalidArgument(
81-
"Attribute(shape) of GaussianRandomOp must be set "
82-
"and shape.size() > 0, but reveived shape.size() is %d",
83-
shape.size()));
84-
}
85-
86-
ctx->SetOutputDim("Out", phi::make_ddim(temp));
87-
}
88-
8959
protected:
9060
framework::OpKernelType GetExpectedKernelType(
9161
const framework::ExecutionContext& ctx) const override {
@@ -171,11 +141,20 @@ Used to initialize tensors with gaussian random generator.
171141
} // namespace paddle
172142

173143
namespace ops = paddle::operators;
174-
REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
175-
ops::GaussianRandomOpMaker);
144+
145+
DECLARE_INFER_SHAPE_FUNCTOR(gaussian_random, GaussianRandomInferShapeFunctor,
146+
PD_INFER_META(phi::GaussianRandomInferMeta));
147+
148+
REGISTER_OPERATOR(
149+
gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker,
150+
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
151+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
152+
GaussianRandomInferShapeFunctor);
153+
176154
REGISTER_OP_CPU_KERNEL(gaussian_random_batch_size_like,
177155
ops::CPUGaussianRandomBatchSizeLikeKernel<float>,
178156
ops::CPUGaussianRandomBatchSizeLikeKernel<double>);
157+
179158
REGISTER_OP_VERSION(gaussian_random)
180159
.AddCheckpoint(
181160
R"ROC(

paddle/fluid/operators/truncated_gaussian_random_op.cc

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ limitations under the License. */
1717
#include <vector>
1818

1919
#include "paddle/fluid/framework/generator.h"
20+
#include "paddle/fluid/framework/infershape_utils.h"
2021
#include "paddle/fluid/framework/op_registry.h"
2122
#include "paddle/fluid/operators/truncated_gaussian_random_op.h"
23+
#include "paddle/phi/infermeta/nullary.h"
2224

2325
namespace paddle {
2426
namespace operators {
@@ -27,26 +29,6 @@ class TruncatedGaussianRandomOp : public framework::OperatorWithKernel {
2729
public:
2830
using framework::OperatorWithKernel::OperatorWithKernel;
2931

30-
void InferShape(framework::InferShapeContext* ctx) const override {
31-
PADDLE_ENFORCE_EQ(
32-
ctx->HasOutput("Out"), true,
33-
platform::errors::NotFound(
34-
"Output(Out) of TruncatedGaussianRandomOp should not be null."));
35-
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
36-
std::vector<int64_t> out_dim;
37-
out_dim.reserve(shape.size());
38-
for (auto dim : shape) {
39-
out_dim.push_back(static_cast<int64_t>(dim));
40-
}
41-
PADDLE_ENFORCE_GT(
42-
shape.size(), 0UL,
43-
platform::errors::InvalidArgument(
44-
"the input shape of TruncatedGaussianRandomOp must be set, "
45-
"But the rank of shape we received is %d",
46-
shape.size()));
47-
ctx->SetOutputDim("Out", phi::make_ddim(out_dim));
48-
}
49-
5032
protected:
5133
framework::OpKernelType GetExpectedKernelType(
5234
const framework::ExecutionContext& ctx) const override {
@@ -99,6 +81,14 @@ Used to initialize tensors with truncated gaussian random generator.
9981
} // namespace paddle
10082

10183
namespace ops = paddle::operators;
102-
REGISTER_OP_WITHOUT_GRADIENT(truncated_gaussian_random,
103-
ops::TruncatedGaussianRandomOp,
104-
ops::TruncatedGaussianRandomOpMaker);
84+
85+
DECLARE_INFER_SHAPE_FUNCTOR(
86+
truncated_gaussian_random, TruncatedGaussianRandomInferShapeFunctor,
87+
PD_INFER_META(phi::TruncatedGaussianRandomInferMeta));
88+
89+
REGISTER_OPERATOR(
90+
truncated_gaussian_random, ops::TruncatedGaussianRandomOp,
91+
ops::TruncatedGaussianRandomOpMaker,
92+
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
93+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
94+
TruncatedGaussianRandomInferShapeFunctor);

paddle/phi/infermeta/nullary.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,29 @@ void EyeInferMeta(int64_t num_rows,
4040
out->set_dims({num_rows, num_columns});
4141
out->set_dtype(dtype);
4242
}
43+
44+
void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
45+
float mean,
46+
float std,
47+
int seed,
48+
DataType dtype,
49+
MetaTensor* out) {
50+
auto out_dims = phi::make_ddim(shape);
51+
out->set_dims(out_dims);
52+
out->set_dtype(dtype);
53+
out->set_layout(DataLayout::NCHW);
54+
}
55+
56+
void GaussianRandomInferMeta(const ScalarArray& shape,
57+
float mean,
58+
float std,
59+
int seed,
60+
DataType dtype,
61+
MetaTensor* out) {
62+
auto out_dims = phi::make_ddim(shape.GetData());
63+
out->set_dims(out_dims);
64+
out->set_dtype(dtype);
65+
out->set_layout(DataLayout::NCHW);
66+
}
67+
4368
} // namespace phi

paddle/phi/infermeta/nullary.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,18 @@ void EyeInferMeta(int64_t num_rows,
4040
DataType dtype,
4141
MetaTensor* out);
4242

43+
void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
44+
float mean,
45+
float std,
46+
int seed,
47+
DataType dtype,
48+
MetaTensor* out);
49+
50+
void GaussianRandomInferMeta(const ScalarArray& shape,
51+
float mean,
52+
float std,
53+
int seed,
54+
DataType dtype,
55+
MetaTensor* out);
56+
4357
} // namespace phi

paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace phi {
2727

2828
template <typename T, typename Context>
2929
void TruncatedGaussianRandomKernel(const Context& dev_ctx,
30-
const ScalarArray& shape,
30+
const std::vector<int>& shape,
3131
float mean,
3232
float std,
3333
int seed,

paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "paddle/phi/core/kernel_registry.h"
2626

2727
#include "paddle/fluid/framework/generator.h"
28-
// #include "paddle/phi/core/generator.h"
2928

3029
namespace phi {
3130

@@ -87,7 +86,7 @@ struct TruncatedNormalOffset {
8786

8887
template <typename T, typename Context>
8988
void TruncatedGaussianRandomKernel(const Context& dev_ctx,
90-
const ScalarArray& shape,
89+
const std::vector<int>& shape,
9190
float mean,
9291
float std,
9392
int seed,

paddle/phi/kernels/truncated_gaussian_random_kernel.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
#include "paddle/phi/core/dense_tensor.h"
2222
#include "paddle/phi/core/device_context.h"
2323

24+
#include "paddle/phi/infermeta/nullary.h"
25+
#include "paddle/phi/kernels/empty_kernel.h"
26+
2427
namespace phi {
2528

2629
// reference: https://gist.github.com/lakshayg/d80172fe5ae3c5d2c2aedb53c250320e
@@ -157,8 +160,8 @@ struct TruncatedNormal {
157160
};
158161

159162
template <typename T, typename Context>
160-
void TruncatedGaussianRandomKernel(const Context& ctx,
161-
const ScalarArray& shape,
163+
void TruncatedGaussianRandomKernel(const Context& dev_ctx,
164+
const std::vector<int>& shape,
162165
float mean,
163166
float std,
164167
int seed,

0 commit comments

Comments
 (0)