Skip to content

Commit 409ac4a

Browse files
authored
Merge pull request #3819 from jacquesqiao/add-getop-to-ctx
add op() to InferShapeContext
2 parents c1feb27 + d323831 commit 409ac4a

9 files changed

Lines changed: 32 additions & 26 deletions

File tree

paddle/framework/operator.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,15 @@ class InferShapeContext {
233233
InferShapeContext(const OperatorBase& op, const Scope& scope)
234234
: op_(op), scope_(scope) {}
235235

236+
const OperatorBase& op() const { return op_; }
237+
238+
const Scope& scope() const { return scope_; }
239+
240+
template <typename T>
241+
inline const T& GetAttr(const std::string& name) const {
242+
return op_.GetAttr<T>(name);
243+
}
244+
236245
size_t InputSize(const std::string& name) const {
237246
return op_.Inputs(name).size();
238247
}
@@ -314,6 +323,7 @@ class InferShapeContext {
314323
return res;
315324
}
316325

326+
private:
317327
const OperatorBase& op_;
318328
const Scope& scope_;
319329
};

paddle/framework/operator_test.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ class CPUKernelTest : public OpKernel {
122122
public:
123123
void Compute(const ExecutionContext& ctx) const {
124124
std::cout << "this is cpu kernel" << std::endl;
125-
std::cout << ctx.op_.DebugString() << std::endl;
125+
std::cout << ctx.op().DebugString() << std::endl;
126126
cpu_kernel_run_num++;
127-
ASSERT_EQ(ctx.op_.Input("x"), "IN1");
128-
ASSERT_EQ(ctx.op_.Output("y"), "OUT1");
127+
ASSERT_EQ(ctx.op().Input("x"), "IN1");
128+
ASSERT_EQ(ctx.op().Output("y"), "OUT1");
129129
}
130130
};
131131

@@ -148,7 +148,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
148148
class CPUKernalMultiInputsTest : public OpKernel {
149149
public:
150150
void Compute(const ExecutionContext& ctx) const {
151-
auto xs = ctx.op_.Inputs("xs");
151+
auto xs = ctx.op().Inputs("xs");
152152
ASSERT_EQ(xs.size(), 3UL);
153153
ASSERT_EQ(xs[0], "x0");
154154
ASSERT_EQ(xs[1], "x1");
@@ -172,10 +172,10 @@ class CPUKernalMultiInputsTest : public OpKernel {
172172
auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
173173
ASSERT_EQ(outTensor0.size(), 2U);
174174

175-
auto k = ctx.op_.Input("k");
175+
auto k = ctx.op().Input("k");
176176
ASSERT_EQ(k, "k0");
177177

178-
auto ys = ctx.op_.Outputs("ys");
178+
auto ys = ctx.op().Outputs("ys");
179179
ASSERT_EQ(ys.size(), 2UL);
180180
ASSERT_EQ(ys[0], "y0");
181181
ASSERT_EQ(ys[1], "y1");

paddle/operators/gaussian_random_op.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,12 @@ template <typename T>
1919
class CPUGaussianRandomKernel : public framework::OpKernel {
2020
public:
2121
void Compute(const framework::ExecutionContext& context) const override {
22-
float mean = context.op_.GetAttr<float>("mean");
23-
float std = context.op_.GetAttr<float>("std");
22+
float mean = context.GetAttr<float>("mean");
23+
float std = context.GetAttr<float>("std");
2424
auto* tensor = context.Output<framework::Tensor>("Out");
2525
T* data = tensor->mutable_data<T>(context.GetPlace());
2626

27-
unsigned int seed =
28-
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
27+
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
2928
std::minstd_rand engine;
3029
if (seed == 0) {
3130
seed = std::random_device()();

paddle/operators/gaussian_random_op.cu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,13 @@ class GPUGaussianRandomKernel : public framework::OpKernel {
4242
void Compute(const framework::ExecutionContext& context) const override {
4343
auto* tensor = context.Output<framework::Tensor>("Out");
4444
T* data = tensor->mutable_data<T>(context.GetPlace());
45-
unsigned int seed =
46-
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
45+
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
4746
if (seed == 0) {
4847
std::random_device rd;
4948
seed = rd();
5049
}
51-
T mean = static_cast<T>(context.op_.GetAttr<float>("mean"));
52-
T std = static_cast<T>(context.op_.GetAttr<float>("std"));
50+
T mean = static_cast<T>(context.GetAttr<float>("mean"));
51+
T std = static_cast<T>(context.GetAttr<float>("std"));
5352
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
5453
ssize_t N = framework::product(tensor->dims());
5554
thrust::transform(index_sequence_begin, index_sequence_begin + N,

paddle/operators/mul_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ class MulOp : public framework::OperatorWithKernel {
2929
auto dim1 = ctx.Input<Tensor>("Y")->dims();
3030
PADDLE_ENFORCE_EQ(dim0.size(), 2,
3131
"input X(%s) should be a tensor with 2 dims, a matrix",
32-
ctx.op_.Input("X"));
32+
ctx.op().Input("X"));
3333
PADDLE_ENFORCE_EQ(dim1.size(), 2,
3434
"input Y(%s) should be a tensor with 2 dims, a matrix",
35-
ctx.op_.Input("Y"));
35+
ctx.op().Input("Y"));
3636
PADDLE_ENFORCE_EQ(
3737
dim0[1], dim1[0],
3838
"First matrix's width must be equal with second matrix's height.");

paddle/operators/scale_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel {
2727
auto* in = context.Input<framework::Tensor>("X");
2828
tensor->mutable_data<T>(in->place());
2929

30-
auto scale = static_cast<T>(context.op_.GetAttr<AttrType>("scale"));
30+
auto scale = static_cast<T>(context.GetAttr<AttrType>("scale"));
3131

3232
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
3333
auto eigen_in = framework::EigenVector<T>::Flatten(*in);

paddle/operators/sgd_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel {
3131
auto param = ctx.Input<Tensor>("param");
3232
auto grad = ctx.Input<Tensor>("grad");
3333
auto param_out = ctx.Output<Tensor>("param_out");
34-
float lr = ctx.op_.GetAttr<float>("learning_rate");
34+
float lr = ctx.GetAttr<float>("learning_rate");
3535

3636
param_out->mutable_data<T>(ctx.GetPlace());
3737

paddle/operators/uniform_random_op.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,15 @@ class CPUUniformRandomKernel : public framework::OpKernel {
2626
void Compute(const framework::ExecutionContext& context) const override {
2727
auto* tensor = context.Output<framework::Tensor>("Out");
2828
T* data = tensor->mutable_data<T>(context.GetPlace());
29-
unsigned int seed =
30-
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
29+
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
3130
std::minstd_rand engine;
3231
if (seed == 0) {
3332
seed = std::random_device()();
3433
}
3534
engine.seed(seed);
3635
std::uniform_real_distribution<T> dist(
37-
static_cast<T>(context.op_.GetAttr<float>("min")),
38-
static_cast<T>(context.op_.GetAttr<float>("max")));
36+
static_cast<T>(context.GetAttr<float>("min")),
37+
static_cast<T>(context.GetAttr<float>("max")));
3938
ssize_t size = framework::product(tensor->dims());
4039
for (ssize_t i = 0; i < size; ++i) {
4140
data[i] = dist(engine);

paddle/operators/uniform_random_op.cu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,13 @@ class GPUUniformRandomKernel : public framework::OpKernel {
4545
void Compute(const framework::ExecutionContext& context) const override {
4646
auto* tensor = context.Output<framework::Tensor>("Out");
4747
T* data = tensor->mutable_data<T>(context.GetPlace());
48-
unsigned int seed =
49-
static_cast<unsigned int>(context.op_.GetAttr<int>("seed"));
48+
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
5049
if (seed == 0) {
5150
std::random_device rd;
5251
seed = rd();
5352
}
54-
T min = static_cast<T>(context.op_.GetAttr<float>("min"));
55-
T max = static_cast<T>(context.op_.GetAttr<float>("max"));
53+
T min = static_cast<T>(context.GetAttr<float>("min"));
54+
T max = static_cast<T>(context.GetAttr<float>("max"));
5655
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
5756
ssize_t N = framework::product(tensor->dims());
5857
thrust::transform(index_sequence_begin, index_sequence_begin + N,

0 commit comments

Comments
 (0)