-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Closed
Description
关于LoDTensor在operator里的用法,合入#4083 之后, 对于非sequence operators使用,需要注意:
- 只需要在
InferShape里对输出用:Output<framework::LoDTensor> - Input以及operator kernel里还可继续用
Tensor。
用SigmoidOp举例:
class SigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of SigmoidOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
"Output(Y) of SigmoidOp should not be null.");
ctx.Output<framework::LoDTensor>("Y")->Resize(
ctx.Input<Tensor>("X")->dims());
}
};SigmoidKernel:
template <typename Place, typename T>
class SigmoidKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>("X");
auto output = context.Output<Tensor>("Y");
output->mutable_data<T>(context.GetPlace());
// The clipping is used in Paddle's raw implenmention
auto X = EigenVector<T>::Flatten(*input);
auto Y = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
Y.device(place) = 1. / (1. + (-X).exp());
}
};有同学觉得一些地方使用LoDTensor, 一些地方用Tensor,这样有点晕,大家怎么看?
如果要改的话,有两种方式:
-
所有operator的任何地方都用
LoDTenor,(王益老师也强调,LoDTensor是我们的一个特点)- 会导致所有operators,即使非sequence operators,任何地方都感知
LoDTensor.
- 会导致所有operators,即使非sequence operators,任何地方都感知
-
InferShapeContext::Output<T>(const std::string& name)函数对Tensor特化,GetMutable<LoDTensor>始终使用LoDTensor:- 会导致通过
InferShapeContext::Output<Tensor>()永远无法在scope里创造出Tensor.
- 会导致通过
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto* var = OutputVar(name);
return var == nullptr ? nullptr : var->GetMutable<LoDTensor>();
}大家有什么建议吗?
Metadata
Metadata
Labels
No labels