optimization of index_select op backward #32955
Conversation
Thanks for your contribution!

Sorry to inform you that d121f02's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
In the Comment area of the Conversation, please describe the purpose of this PR, the performance before and after the change, and related information.
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DDim = framework::DDim;
Necessary blank lines help code readability; please do not delete them.
auto input_dim_size = input_dim.size();
auto output_dim = x_grad->dims();
std::vector<T> out_vec(x_grad->numel(), 0);
std::memset(out_data, 0.0, x_grad->numel() * sizeof(T));
SetConstant could be used here. Also, could the zero-initialization be moved into the for loop after L196, initializing one portion at a time — would that be more cache-friendly?
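A minimal sketch of the reviewer's suggestion (the names and helper are illustrative, not the actual Paddle kernel): rather than one big memset over the whole `x_grad` buffer up front, zero each referenced output slice just before accumulating into it, so the freshly-zeroed cache lines are reused immediately by the accumulation pass.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical helper: scatter-accumulate out_grad rows into x_grad rows
// selected by index, zeroing only the slices that will be written.
// NOTE: rows of x_grad never referenced by `index` still need to be
// zeroed elsewhere; this sketch covers only the referenced slices.
template <typename T, typename IndexT>
void index_select_grad_slices(const T* out_grad, const IndexT* index,
                              T* x_grad, int index_size, int slice_size) {
  // Pass 1: zero every slice that is about to receive a contribution.
  // (Done as a separate pass so duplicate indices accumulate correctly.)
  for (int i = 0; i < index_size; ++i) {
    std::memset(x_grad + index[i] * slice_size, 0, sizeof(T) * slice_size);
  }
  // Pass 2: accumulate each out_grad slice into its target x_grad slice.
  for (int i = 0; i < index_size; ++i) {
    T* dst = x_grad + index[i] * slice_size;
    const T* src = out_grad + i * slice_size;
    for (int j = 0; j < slice_size; ++j) dst[j] += src[j];
  }
}
```

Whether this beats a single upfront memset depends on how much of `x_grad` the index actually touches; profiling both variants, as the reviewer suggests, is the way to decide.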
};

template <typename T>
struct IndexSelectAdd<
This group of functors does not seem very meaningful; apart from the floating-point case, everything uses the generic form below.
template <typename platform::cpu_isa_t isa, typename T, class Enable = void>
struct IndexSelectAdd {
void operator()(int n, const T* src, T* dst) {
for (int i = 0; i < n; i++) {
dst[i] += src[i];
}
}
};
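For context, a self-contained sketch of the dispatch pattern under discussion (a hedged reconstruction, not the exact Paddle code): the generic functor above handles all types, and a partial specialization selected via `std::enable_if` would route floating-point types to a vectorized path (e.g. a BLAS VADD/AXPY call in the real kernel).

```cpp
#include <cassert>
#include <type_traits>

// Generic fallback: plain element-wise accumulation.
template <typename T, class Enable = void>
struct IndexSelectAdd {
  void operator()(int n, const T* src, T* dst) const {
    for (int i = 0; i < n; ++i) dst[i] += src[i];
  }
};

// Specialization chosen for floating-point types. In the PR this is where
// a vectorized BLAS add would go; the loop below is a stand-in.
template <typename T>
struct IndexSelectAdd<
    T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
  void operator()(int n, const T* src, T* dst) const {
    for (int i = 0; i < n; ++i) dst[i] += src[i];  // stand-in for blas.VADD
  }
};
```

The reviewer's point stands: if the non-float path is just the generic loop, the extra functor machinery adds little over a single templated function with one floating-point branch.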
auto& out_grad = out_grad_var->Get<LoDTensor>();
auto* x_grad = x_grad_var->GetMutable<framework::LoDTensor>();
int dim = context.Attr<int>("dim");
if (dim < 0) {
Line212 - Line219 can be changed to:
auto *x_grad = ctx.Input<framework::LoDTensor>("X");
auto *index = ctx.Input<framework::LoDTensor>("Index");
auto *out_grad = ctx.Output<framework::LoDTensor>("Out");
};

template <typename T, typename IndexT = int>
#if ((!defined __NVCC__) && (!defined __HIPCC__))
Is this macro still necessary here?
void operator()(const framework::ExecutionContext& ctx, int slice_size,
                const T* src_pointer, const T* p_pointer, T* dist_pointer) {
  auto blas = math::GetBlas<DeviceContext, T>(ctx);
  blas.VADD(slice_size, src_pointer, p_pointer, dist_pointer);
When using blas, it would be worth measuring the speedup under different OMP settings.
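A minimal harness for that measurement (names are illustrative, and the plain loop below is a stand-in for what `blas.VADD` computes): time the slice-wise add so that runs under different `OMP_NUM_THREADS` settings can be compared when the real kernel is linked against an OpenMP-enabled BLAS.

```cpp
#include <cassert>
#include <chrono>
#include <vector>

// Stand-in for blas.VADD: dst[i] = src[i] + p[i] over one slice.
template <typename T>
void vadd(int n, const T* src, const T* p, T* dst) {
  for (int i = 0; i < n; ++i) dst[i] = src[i] + p[i];
}

// Time one vadd over a buffer of n elements, in microseconds.
template <typename T>
double time_vadd_us(int n) {
  std::vector<T> src(n, T(1)), p(n, T(2)), dst(n);
  auto t0 = std::chrono::steady_clock::now();
  vadd(n, src.data(), p.data(), dst.data());
  auto t1 = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::micro>(t1 - t0).count();
}
```

Running the compiled benchmark as e.g. `OMP_NUM_THREADS=1`, `2`, `4`, `8` and comparing the reported times would give the speedup curve the reviewer asks about.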
Xreki left a comment:
LGTM
template <typename DeviceContext, typename T, typename IndexT = int>
void IndexSelectGradInner(const framework::ExecutionContext& context,
-                         const LoDTensor& out_grad, const LoDTensor& index,
+                         const LoDTensor* out_grad, const LoDTensor* index,
Do not change the parameter types; inputs that are not modified should use const Tensor&.
PR types
Performance optimization

PR changes
OPs

Describe
This PR optimizes the backward pass of the index_select op. The benchmark data is shown below:
(image: benchmark results comparing the original implementation, the optimized version, and PyTorch)
Compared with the original implementation and PyTorch, the optimized version shows improved performance.