Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,19 +402,28 @@ void OperatorWithKernel::Run(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second;

ExecutionContext ctx(*this, scope, *dev_ctx);
auto kernel_key = GetKernelType(ctx);
auto kernel_iter = kernels.find(kernel_key);
auto actual_kernel_key = GetActualKernelType(ctx);
auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (actual_kernel_key != expected_kernel_key) {
    TransformInputs(ctx, from=actual_kernel_key, to=expected_kernel_key);
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will be done in next PR~

auto kernel_iter = kernels.find(expected_kernel_key);

if (kernel_iter == kernels.end()) {
PADDLE_THROW("The operator %s does not support %s", type_, kernel_key);
PADDLE_THROW("The operator %s does not support %s", type_,
expected_kernel_key);
}

kernel_iter->second->Compute(ctx);
}
OpKernelType OperatorWithKernel::GetKernelType(

OpKernelType OperatorWithKernel::GetActualKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}

OpKernelType OperatorWithKernel::GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const {
return actual_kernel_type;
}

proto::DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
Expand Down
9 changes: 8 additions & 1 deletion paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ constexpr char kGradVarSuffix[] = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

// define some kernel hint
const std::string kUseCPU = "use_cpu";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

constexpr char kUseCPU[]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will fix in next pr

const std::string kUseCUDNN = "use_cudnn";
const std::string kUseMKLDNN = "use_mkldnn";

inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix;
}
Expand Down Expand Up @@ -373,7 +378,9 @@ class OperatorWithKernel : public OperatorBase {
}

protected:
virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetActualKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const;

private:
// indicate kernel DataType by input data. Defaultly all input data must be
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class OpWithKernelTest : public OperatorWithKernel {

protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
OpKernelType GetActualKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
}
};
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/accuracy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/auc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class AucOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/chunk_eval_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32,
ctx.device_context());
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/compare_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ class CompareOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
Expand Down
8 changes: 7 additions & 1 deletion paddle/operators/crf_decoding_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,18 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
ctx.device_context());
}

framework::OpKernelType GetExpectedKernelType(
const framework::OpKernelType& actual_kernel_type) const override {
return framework::OpKernelType(actual_kernel_type.data_type_,
platform::CPUPlace());
}
};
} // namespace operators
} // namespace paddle
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/cross_entropy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
Expand Down Expand Up @@ -101,7 +101,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/fill_constant_batch_size_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/gather_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class GatherOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
Expand All @@ -57,7 +57,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/gaussian_random_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/linear_chain_crf_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of linear_chain_crf
// is determined by its input "Emission".
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
Expand Down Expand Up @@ -242,7 +242,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of output of the linear_chain_crf_grad
// operator is determined by its input: gradients of LogLikelihood.
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/lod_reset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
Expand Down Expand Up @@ -97,7 +97,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/logical_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ class LogicalOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
// LogicalOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/lookup_table_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
Expand Down Expand Up @@ -98,7 +98,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/lstm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class LSTMOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
Expand Down Expand Up @@ -260,7 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/multiplex_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
Expand Down Expand Up @@ -102,7 +102,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/nce_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class NCEOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
Expand Down Expand Up @@ -166,7 +166,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/pool_with_index_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
Expand All @@ -90,7 +90,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/positive_negative_pair_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Score")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/precision_recall_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/roi_pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
Expand All @@ -89,7 +89,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/scatter_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class ScatterOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
Expand All @@ -68,7 +68,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/sequence_pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
Expand Down
Loading