-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[PHI]Add new Tensor type and migrate save_combine kernel #47856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
60f8c2a
60e3717
ebbee9a
a87999d
9929a28
c4a66e6
ee11ac6
01e976c
bec4990
4010a6f
4ec7569
c4cca54
9bbd512
e552774
059e85f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/phi/core/extended_tensor.h" | ||
|
|
||
| namespace phi { | ||
|
|
||
| int64_t ExtendedTensor::numel() const { | ||
| PADDLE_THROW(phi::errors::Unavailable( | ||
| "ExtendedTensor does not support `numel` method.")); | ||
| } | ||
|
|
||
| const DDim& ExtendedTensor::dims() const { | ||
| PADDLE_THROW(phi::errors::Unavailable( | ||
| "ExtendedTensor does not support `dims` method.")); | ||
| } | ||
|
|
||
| const Place& ExtendedTensor::place() const { | ||
| PADDLE_THROW(phi::errors::Unavailable( | ||
| "ExtendedTensor does not support `place` method.")); | ||
| } | ||
|
|
||
| DataType ExtendedTensor::dtype() const { | ||
| PADDLE_THROW(phi::errors::Unavailable( | ||
| "ExtendedTensor does not support `dtype` method.")); | ||
| } | ||
|
|
||
| DataLayout ExtendedTensor::layout() const { | ||
| PADDLE_THROW(phi::errors::Unavailable( | ||
| "ExtendedTensor does not support `dtype` method.")); | ||
| } | ||
|
|
||
| bool ExtendedTensor::valid() const { | ||
| PADDLE_THROW(phi::errors::Unavailable( | ||
| "ExtendedTensor does not support `valid` method.")); | ||
| } | ||
|
|
||
| bool ExtendedTensor::initialized() const { | ||
| PADDLE_THROW(phi::errors::Unavailable( | ||
| "ExtendedTensor does not support `initialized` method.")); | ||
| } | ||
|
|
||
| void* ExtendedTensor::AllocateFrom(Allocator* allocator, | ||
| DataType dtype, | ||
| size_t requested_size) { | ||
| PADDLE_THROW(phi::errors::Unavailable( | ||
| "ExtendedTensor does not support `AllocateFrom` method.")); | ||
| } | ||
|
|
||
| } // namespace phi | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "paddle/phi/core/allocator.h" | ||
| #include "paddle/phi/core/tensor_base.h" | ||
| #include "paddle/phi/core/tensor_meta.h" | ||
|
|
||
| namespace phi { | ||
|
|
||
| /// \brief The ExtendedTensor is a interface for custom designed class. | ||
| /// If you want to pass some self-designed data as input/output to kernels, | ||
| /// you can inherit from this class to store your self-designed data. | ||
| class ExtendedTensor : public TensorBase, | ||
|
||
| public TypeInfoTraits<TensorBase, ExtendedTensor> { | ||
| public: | ||
| ExtendedTensor() = default; | ||
| virtual ~ExtendedTensor() = default; | ||
|
|
||
| public: | ||
| /// \brief Returns the name of the class for type traits. | ||
| /// \return The name of the class. | ||
| static const char* name() { return "ExtendedTensor"; } | ||
|
|
||
| int64_t numel() const override; | ||
|
|
||
| const DDim& dims() const override; | ||
|
|
||
| const Place& place() const override; | ||
|
|
||
| DataType dtype() const override; | ||
|
|
||
| DataLayout layout() const override; | ||
|
|
||
| bool valid() const override; | ||
|
|
||
| bool initialized() const override; | ||
|
|
||
| void* AllocateFrom(Allocator* allocator, | ||
| DataType dtype, | ||
| size_t requested_size = 0) override; | ||
| }; | ||
|
|
||
| } // namespace phi | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -142,6 +142,7 @@ enum class AttributeType { | |
| DATA_TYPE, | ||
| DATA_LAYOUT, | ||
| PLACE, | ||
| STRING_PTR, | ||
|
||
| }; | ||
|
|
||
| struct AttributeArgDef { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,7 @@ | |
| #include "paddle/phi/core/kernel_utils.h" | ||
| #include "paddle/phi/core/macros.h" | ||
| #include "paddle/phi/core/type_defs.h" | ||
| #include "paddle/phi/core/vocab.h" | ||
|
||
|
|
||
| namespace phi { | ||
|
|
||
|
|
@@ -100,6 +101,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { | |
| default_tensor_layout, | ||
| default_key.dtype(), | ||
| arg_type); | ||
| } else if (arg_type == | ||
| std::type_index(typeid(const std::vector<const Vocab*>&))) { | ||
| args_def->AppendInput(default_key.backend(), | ||
| default_tensor_layout, | ||
| default_key.dtype(), | ||
| arg_type); | ||
| } else if (arg_type == std::type_index(typeid( | ||
| const std::vector<const SelectedRows*>&))) { | ||
| args_def->AppendInput(default_key.backend(), | ||
|
|
@@ -203,6 +210,8 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { | |
| args_def->AppendAttribute(AttributeType::FLOAT64); | ||
| } else if (arg_type == std::type_index(typeid(std::string))) { | ||
| args_def->AppendAttribute(AttributeType::STRING); | ||
| } else if (arg_type == std::type_index(typeid(std::string*))) { | ||
| args_def->AppendAttribute(AttributeType::STRING_PTR); | ||
| } else if (arg_type == | ||
| std::type_index(typeid(const std::vector<bool>&))) { | ||
| args_def->AppendAttribute(AttributeType::BOOLS); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些函数可以声明为纯虚函数吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是故意设计成这样的,减少自定义输入输出类型继承后的不合理代码