Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions rfcs/APIs/20230918_api_design_for_put_along_axis.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# paddle.put_along_axis 设计文档
# paddle.put_along_axis API 增强设计文档

| API名称 | paddle.put_along_axis |
| ------------ | -------------------------------------- |
Expand Down Expand Up @@ -225,18 +225,19 @@ if (op == SCATTER_GATHER_OP::REDUCE_MEAN) {

其中 put_along_axis_ 是 put_along_axis 的 inplace 版本。

- `arr (Tensor) - 输入的 Tensor 作为目标矩阵,数据类型为:float32、float64。`
- `indices (Tensor) - 索引矩阵,包含沿轴提取 1d 切片的下标,必须和 arr 矩阵有相同的维度,需要能够 broadcast 与 arr 矩阵对齐,数据类型为:int、int64。`
- `value (float)- 需要插入的值,形状和维度需要能够被 broadcast 与 indices 矩阵匹配,数据类型为:float32、float64。`
- `arr (Tensor)` - 输入的 Tensor 作为目标矩阵,数据类型为:float32、float64, int32, int64。 GPU 额外支持float16和bfloat16。
- `indices (Tensor)` - 索引矩阵,包含沿轴提取 1d 切片的下标,必须和 arr 矩阵有相同的维度,需要能够 broadcast 与 arr 矩阵对齐,数据类型为:int、int64。
- `value (float)` - 需要插入的值,形状和维度需要能够被 broadcast 与 indices 矩阵匹配,数据类型同 arr。
- `axis (int) - 指定沿着哪个维度获取对应的值,数据类型为:int。`
- `reduce (str,可选) - 归约操作类型,默认为 assign,可选为 add, mul,amax, amin, mean。不同的规约操作插入值 value 对于输入矩阵 arr 会有不同的行为,如为 assgin 则覆盖输入矩阵,add 则累加至输入矩阵,mul 则累乘至输入矩阵,amax 则取最大至输入矩阵, amin 则取最小至输入矩阵, mean 则取平均至输入矩阵。`
- `reduce (str,可选) - 归约操作类型,默认为 assign,可选为 add, mul/multiply,amax, amin, mean。不同的规约操作插入值 value 对于输入矩阵 arr 会有不同的行为,如为 assgin 则覆盖输入矩阵,add 则累加至输入矩阵,mul/multiply 则累乘至输入矩阵,amax 则取最大至输入矩阵, amin 则取最小至输入矩阵, mean 则取平均至输入矩阵。`
- `include_self (bool,可选)` - arr 张量中的元素是否包含在规约中。默认值 include_self = True.


相比于 torch.scatter_reduce 主要差异点为:

1. reduce 新增 max, min, mean 等规约方式。 torch 不支持 assgin。
2. torch 支持 include_self 配置。
1. reduce 目前只支持 add、assign、mul/multiply。
2. 模型支持 include_self=True的实现,不支持False的实现,且没有对应的参数。
3. 反向梯度计算也存在差异。

## 底层OP设计

Expand Down Expand Up @@ -359,12 +360,14 @@ reduce=sum 的计算逻辑和 mean 类似。

## API实现方案

在底层 paddle\phi\kernels\funcs\gather_scatter_functor.cc 增加对应的归约算子,在 python\paddle\tensor\manipulation.py 中修改下 docstring 即可
在底层 paddle/phi/kernels/cpu/put_along_axis_kernel.cc 和 paddle/phi/kernels/gpu/put_along_axis_kernel.cu 增加对应的归约算子。
在底层 paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc 和 paddle/phi/kernels/gpu/put_along_axis_grad_kernel.cu 增加对应的归约梯度算子。
在 python\paddle\tensor\manipulation.py 中修改下 api 和 docstring 即可。

# 六、测试和验收的考量

测试考虑的 case 如下:
- 增加 reduce 分别为 'min'、'max' 和 'mean' 时的单测.
- 增加 reduce 分别为 'amin'、'amax' 和 'mean' 时的单测.
- 增加 include_self 的单侧。
- 验证反向梯度是否正确。

Expand Down