Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions paddle/framework/op_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <functional>
#include <string>

#include "paddle/framework/framework.pb.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very bad code here

Since op_kernel.h does not pull in its own dependencies, it cannot be included in an arbitrary order.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed op_kernel.h, since the kernel type will be moved to op_kernel_type.h.

namespace paddle {
namespace framework {

// Kernel hint keys. Operators may carry attributes with these names to
// steer which kernel gets selected for them.
const std::string kForceCPU = "force_cpu";
const std::string kUseCUDNN = "use_cudnn";
const std::string kUseMKLDNN = "use_mkldnn";

// Identifies one concrete kernel of an operator: the place (device) it runs
// on plus the data type it computes with. Used as the lookup key of the
// per-operator kernel registry.
struct OpKernelType {
  // Hash functor so OpKernelType can key a std::unordered_map. Packs the
  // place index into the low NUM_PLACE_TYPE_LIMIT_IN_BIT bits and the data
  // type above them, then hashes the combined integer.
  struct Hash {
    std::hash<int> hash_;
    size_t operator()(const OpKernelType& key) const {
      // which() is the runtime index of the variant alternative held by
      // place_ — TODO confirm it stays within NUM_PLACE_TYPE_LIMIT_IN_BIT.
      int place = key.place_.which();
      int data_type = static_cast<int>(key.data_type_);
      int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
                     (place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
      return hash_(pre_hash);
    }
  };

  platform::Place place_;
  proto::DataType data_type_;

  OpKernelType(proto::DataType data_type, platform::Place place)
      : place_(place), data_type_(data_type) {}

  // Convenience overload: take the place from a device context.
  OpKernelType(proto::DataType data_type,
               const platform::DeviceContext& dev_ctx)
      : place_(dev_ctx.GetPlace()), data_type_(data_type) {}

  // NOTE: keys compare equal when their places are of the same *class*
  // (places_are_same_class), not only when they are identical places.
  bool operator==(const OpKernelType& o) const {
    return platform::places_are_same_class(place_, o.place_) &&
           data_type_ == o.data_type_;
  }

  // Complement of operator== so callers can write
  // `actual_kernel_key != expected_kernel_key`.
  bool operator!=(const OpKernelType& o) const { return !(*this == o); }
};

}  // namespace framework
}  // namespace paddle
17 changes: 13 additions & 4 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -404,19 +404,28 @@ void OperatorWithKernel::Run(const Scope& scope,

// check if op[type] have kernel for kernel_key
OpKernelMap& kernels = kernels_iter->second;
auto kernel_key = GetKernelType(ctx);
auto kernel_iter = kernels.find(kernel_key);
auto actual_kernel_key = GetActualKernelType(ctx);
auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (actual_kernel_key != expected_kernel_key) {
    TransformInputs(ctx, from=actual_kernel_key, to=expected_kernel_key);
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be done in the next PR.

auto kernel_iter = kernels.find(expected_kernel_key);

if (kernel_iter == kernels.end()) {
PADDLE_THROW("The operator %s does not support %s", type_, kernel_key);
PADDLE_THROW("The operator %s does not support %s", type_,
expected_kernel_key);
}

kernel_iter->second->Compute(ctx);
}
OpKernelType OperatorWithKernel::GetKernelType(

// Build the kernel key the operator would naturally use: the data type
// indicated by its inputs, on the place of the current execution context.
OpKernelType OperatorWithKernel::GetActualKernelType(
    const ExecutionContext& ctx) const {
  const auto indicated_data_type = IndicateDataType(ctx);
  return OpKernelType(indicated_data_type, ctx.GetPlace());
}

// Default policy: expect exactly the actual kernel type, so no input
// transformation is needed. Operators that must run on a particular
// place/type (e.g. CPU-only kernels) override this.
OpKernelType OperatorWithKernel::GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const {
return actual_kernel_type;
}

proto::DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
Expand Down
33 changes: 4 additions & 29 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_kernel.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor.h"
Expand Down Expand Up @@ -345,34 +346,6 @@ class OpKernel : public OpKernelBase {
using ELEMENT_TYPE = T;
};

// Key identifying one concrete kernel of an operator: the place (device)
// it runs on plus the data type it computes with.
// NOTE(review): this definition is the deletion side of the diff — it was
// moved verbatim into paddle/framework/op_kernel.h by this PR.
struct OpKernelType {
// Hash functor so OpKernelType can key a std::unordered_map: packs the
// place index into the low NUM_PLACE_TYPE_LIMIT_IN_BIT bits and the data
// type above them, then hashes the combined integer.
struct Hash {
std::hash<int> hash_;
size_t operator()(const OpKernelType& key) const {
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
(place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
return hash_(pre_hash);
}
};

platform::Place place_;
proto::DataType data_type_;

OpKernelType(proto::DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {}

// Convenience overload: take the place from a device context.
OpKernelType(proto::DataType data_type,
const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}

// NOTE: keys compare equal when their places are of the same *class*
// (places_are_same_class), not only when they are identical places.
bool operator==(const OpKernelType& o) const {
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_;
}
};

class OperatorWithKernel : public OperatorBase {
public:
using OpKernelMap =
Expand Down Expand Up @@ -405,7 +378,9 @@ class OperatorWithKernel : public OperatorBase {
}

protected:
virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetActualKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const;

private:
// indicate kernel DataType by input data. Defaultly all input data must be
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class OpWithKernelTest : public OperatorWithKernel {

protected:
// Intentionally empty: this test op has no shapes to infer.
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
// Test fixture: always report an FP32 kernel on the context's place.
OpKernelType GetActualKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
}
};
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/accuracy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/auc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class AucOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/batch_norm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/chunk_eval_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32,
ctx.device_context());
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/compare_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ class CompareOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
Expand Down
8 changes: 7 additions & 1 deletion paddle/operators/crf_decoding_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,18 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
ctx.device_context());
}

// CRF decoding keeps the indicated data type but pins the expected place
// to CPUPlace — presumably because only a CPU kernel exists for this op;
// TODO(review) confirm against the registered kernels.
framework::OpKernelType GetExpectedKernelType(
    const framework::OpKernelType& actual_kernel_type) const override {
  const auto expected_data_type = actual_kernel_type.data_type_;
  return framework::OpKernelType(expected_data_type, platform::CPUPlace());
}
};
} // namespace operators
} // namespace paddle
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/cross_entropy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
Expand Down Expand Up @@ -101,7 +101,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of cross_entropy
// is determined by its input "X".
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/fill_constant_batch_size_like_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/gather_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class GatherOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
Expand All @@ -57,7 +57,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/gaussian_random_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/linear_chain_crf_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of computation kernel of linear_chain_crf
// is determined by its input "Emission".
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
Expand Down Expand Up @@ -242,7 +242,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
protected:
// Explicitly set that the data type of output of the linear_chain_crf_grad
// operator is determined by its input: gradients of LogLikelihood.
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/lod_reset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
Expand Down Expand Up @@ -97,7 +97,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/logical_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ class LogicalOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
// LogicalOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/lookup_table_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
Expand Down Expand Up @@ -98,7 +98,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/lstm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class LSTMOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
Expand Down Expand Up @@ -260,7 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/multiplex_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
Expand Down Expand Up @@ -102,7 +102,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
Expand Down
4 changes: 2 additions & 2 deletions paddle/operators/nce_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class NCEOp : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
Expand Down Expand Up @@ -166,7 +166,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
}

protected:
framework::OpKernelType GetKernelType(
framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
Expand Down
Loading