Refactoring InferShape#3946
Conversation
paddle/framework/operator.h
Outdated
| return GetDim(op_.Output(name)); | ||
| } | ||
|
|
||
| void SetOutputDim(const std::string& name, const DDim& dim) const { |
There was a problem hiding this comment.
Set methods should change private data members, so they cannot be const
paddle/framework/operator.h
Outdated
| if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) { | ||
| t = var->GetMutable<LoDTensor>(); | ||
| } else { | ||
| t = const_cast<Tensor*>(GetTensorFromVar(scope_.FindVar(name))); |
There was a problem hiding this comment.
const_cast should not be used.
paddle/framework/operator.h
Outdated
| protected: | ||
| virtual void InferShape(const InferShapeContext& ctx) const = 0; | ||
| virtual void InferShape(const InferShapeContext& ctx) const {} | ||
| virtual void InferShape(const InferShapeContextBase& ctx) const {} |
There was a problem hiding this comment.
Why there are two interfaces?
There was a problem hiding this comment.
The old one is for compiling and test, has been removed after transform all the old operators.
paddle/framework/ddim.cc
Outdated
| return result; | ||
| } | ||
|
|
||
| std::string debug_str(const DDim& ddim) { |
There was a problem hiding this comment.
I use it a lot when debug the code
There was a problem hiding this comment.
removed since we can use cout << ddim to print the debug string.
paddle/framework/operator.h
Outdated
| } | ||
|
|
||
| private: | ||
| Tensor* GetTensor(const std::string& name, bool allocate) const { |
There was a problem hiding this comment.
allocate could be a template parameter.
paddle/framework/operator.h
Outdated
| if (allocate) { | ||
| t = var->GetMutable<LoDTensor>(); | ||
| } else { | ||
| PADDLE_ENFORCE(false, "Variable(%s) should be tensor", name); |
paddle/framework/operator.h
Outdated
| const platform::DeviceContext& device_context_; | ||
| }; | ||
|
|
||
| class RunTimeInferShapeContext : public InferShapeContextBase { |
There was a problem hiding this comment.
RunTime-> Runtime
runtime is a word
|
|
||
| protected: | ||
| virtual void InferShape(const InferShapeContext& ctx) const = 0; | ||
| virtual void InferShape(InferShapeContextBase* ctx) const = 0; |
There was a problem hiding this comment.
This file should be changed a lot. void OperatorBase::InferShape(const Scope& scope) should be removed. void InferShape(InferShapeContextBase* ctx) const should be public.
There was a problem hiding this comment.
Yes, This will be done in next pr that will modify all the python related code.
fix: #4183
design: #4142
support compile time and runtime infershape.