Skip to content

Commit 24392e6

Browse files
[Prim] Add index_put_grad for static decomposition (#73747)
* support index_put_grad in static prim * fix * fix typo * disable cinn in UT
1 parent 94e2150 commit 24392e6

File tree

7 files changed

+207
-0
lines changed

7 files changed

+207
-0
lines changed

paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@
138138
'unsqueeze_grad',
139139
'p_norm_grad',
140140
'masked_fill_grad',
141+
'index_put_grad',
141142
'index_add_grad',
142143
]
143144

paddle/fluid/primitive/codegen/decomp_vjp_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@
162162
'swiglu_grad',
163163
'p_norm_grad',
164164
'masked_fill_grad',
165+
'index_put_grad',
165166
'index_add_grad',
166167
] # custom vjp list of composite op
167168

paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,67 @@ void roll_grad(const Tensor& x,
440440
}
441441
}
442442

443+
template <typename T>
444+
void index_put_grad(const Tensor& x,
445+
const std::vector<Tensor>& indices,
446+
const Tensor& value,
447+
const Tensor& out_grad,
448+
const bool accumulate,
449+
Tensor* x_grad,
450+
Tensor* value_grad) {
451+
if (x_grad) {
452+
if (accumulate) {
453+
by_pass<T>(out_grad, x_grad);
454+
} else {
455+
Tensor x_grad_tmp;
456+
if (has_dynamic_shape(x.shape()) ||
457+
std::any_of(
458+
indices.cbegin(),
459+
indices.cend(),
460+
[](const Tensor& t) { return has_dynamic_shape(t.shape()); }) ||
461+
has_dynamic_shape(out_grad.shape())) {
462+
x_grad_tmp = index_put<T>(
463+
out_grad,
464+
indices,
465+
backend::full_with_tensor<T>(
466+
shape64<T>(value), 0, out_grad.dtype(), out_grad.place()));
467+
} else {
468+
x_grad_tmp = index_put<T>(out_grad,
469+
indices,
470+
full<T>(common::vectorize(value.dims()),
471+
0,
472+
out_grad.dtype(),
473+
out_grad.place()));
474+
}
475+
set_output<T>(x_grad_tmp, x_grad);
476+
}
477+
}
478+
479+
if (value_grad) {
480+
std::vector<Tensor> indices_vec;
481+
482+
if (has_dynamic_shape(x.shape()) ||
483+
std::any_of(
484+
indices.cbegin(),
485+
indices.cend(),
486+
[](const Tensor& t) { return has_dynamic_shape(t.shape()); }) ||
487+
has_dynamic_shape(out_grad.shape())) {
488+
for (int i = 0; i < indices.size(); ++i) {
489+
indices_vec.push_back(backend::unsqueeze<T>(
490+
indices[i], full<T>({1}, -1, DataType::INT64, indices[i].place())));
491+
}
492+
} else {
493+
for (int i = 0; i < indices.size(); ++i) {
494+
indices_vec.push_back(unsqueeze<T>(indices[i], {-1}));
495+
}
496+
}
497+
498+
Tensor stacked_indices = concat<T>(indices_vec, -1);
499+
Tensor value_grad_tmp = gather_nd<T>(out_grad, stacked_indices);
500+
set_output<T>(value_grad_tmp, value_grad);
501+
}
502+
}
503+
443504
template <typename T>
444505
void transpose_grad(const Tensor& grad_out,
445506
const std::vector<int>& perm,

paddle/fluid/primitive/primitive/primitive.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
- gather_nd
7373
- scatter
7474
- scatter_nd
75+
- index_put
7576
- scatter_nd_add
7677
- put_along_axis
7778
- take_along_axis

paddle/phi/ops/yaml/backward.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,6 +1732,7 @@
17321732
data_transform :
17331733
skip_transform : indices
17341734
backward : index_put_double_grad
1735+
no_need_buffer: x, value
17351736

17361737
- backward_op : index_sample_grad
17371738
forward : index_sample (Tensor x, Tensor index) -> Tensor(out)

python/paddle/autograd/backward_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
"pd_op.unsqueeze",
108108
"pd_op.where",
109109
"pd_op.p_norm",
110+
"pd_op.index_put",
110111
"pd_op.index_add",
111112
"pd_op.elu",
112113
"pd_op.masked_fill",

test/legacy_test/test_index_put_op.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,5 +1086,146 @@ def test_dygraph_forward(self):
10861086
)
10871087

10881088

1089+
class TestIndexPutPrim(unittest.TestCase):
    """Checks the static (prim) decomposition of ``index_put_grad``.

    Compares dx/dvalue from two ``paddle.jit.to_static`` programs (one with
    dynamic-shape InputSpecs, one with fixed shapes inferred from inputs)
    against eager-mode gradients, for both ``accumulate`` modes and several
    x/indices/value shape combinations.
    """
    # NOTE(review): the original patch defined a broken ``__int__`` method
    # (``self().__init__()``) that would raise TypeError if ever invoked;
    # it served no purpose and has been removed.

    def test_prim(self):
        try:
            paddle.framework.core._set_prim_all_enabled(True)
            for accumulate in [False, True]:
                for x_shape, indices_shape, value_shape in [
                    ([16], [10], [10]),
                    ([16, 16], [20, 2], [20]),
                    ([12, 13, 14], [88, 1], [88, 13, 14]),
                    ([12, 13, 14], [88, 2], [88, 14]),
                    ([12, 13, 14], [88, 3], [88]),
                    ([12, 13, 14], [12 * 13 * 14, 3], [12 * 13 * 14]),
                ]:
                    n_indices = indices_shape[0]
                    index_dim_size = (
                        indices_shape[1] if len(indices_shape) > 1 else 1
                    )

                    x_np = np.random.randn(*x_shape)
                    # One 1-D int index tensor per indexed leading axis;
                    # negative values exercise negative indexing.
                    indices_np = tuple(
                        np.random.randint(-x_shape[i], x_shape[i], [n_indices])
                        for i in range(max(index_dim_size, 1))
                    )
                    value_np = np.random.randn(*value_shape).astype("float32")

                    x_pd = paddle.to_tensor(
                        x_np.copy(),
                        "float32",
                        stop_gradient=False,
                    )
                    indices_pd = tuple(
                        paddle.to_tensor(
                            indice.copy(),
                            "int64",
                            stop_gradient=True,
                        )
                        for indice in indices_np
                    )
                    value_pd = paddle.to_tensor(
                        value_np.copy(),
                        "float32",
                        stop_gradient=False,
                    )

                    out_pd = paddle.index_put(
                        x_pd, indices_pd, value_pd, accumulate=accumulate
                    )
                    dout_np = np.random.randn(*out_pd.shape)
                    dout_pd = paddle.to_tensor(
                        dout_np.copy(),
                        "float32",
                        stop_gradient=False,
                    )

                    # ``accumulate`` must be baked in as a literal so the
                    # traced static program uses the right mode.
                    if accumulate:

                        def compute_dx_dv(x, indices, v, dy, accumulate=True):
                            y = paddle.index_put(x, indices, v, True)
                            return paddle.grad(y, [x, v], dy, create_graph=True)

                    else:

                        def compute_dx_dv(x, indices, v, dy, accumulate=False):
                            y = paddle.index_put(x, indices, v, False)
                            return paddle.grad(y, [x, v], dy, create_graph=True)

                    # Eager-mode reference gradients.
                    dx_ref, dv_ref = compute_dx_dv(
                        x_pd, indices_pd, value_pd, dout_pd
                    )

                    # Static program with dynamic (unknown) shapes.
                    # NOTE(review): these InputSpecs are rank-2 while some
                    # test inputs are rank 1 or 3 — confirm paddle tolerates
                    # the rank mismatch or tailor the specs per case.
                    st_func1 = paddle.jit.to_static(
                        compute_dx_dv,
                        input_spec=[
                            paddle.static.InputSpec(
                                shape=[-1, -1], dtype='float32'
                            ),
                            tuple(
                                paddle.static.InputSpec(
                                    shape=[-1], dtype='int64'
                                )
                                for _ in range(len(indices_pd))
                            ),
                            paddle.static.InputSpec(
                                shape=[-1, -1], dtype='float32'
                            ),
                            paddle.static.InputSpec(
                                shape=[-1, -1], dtype='float32'
                            ),
                        ],
                        full_graph=True,
                        backend=None,
                    )
                    dx_1, dv_1 = st_func1(x_pd, indices_pd, value_pd, dout_pd)

                    # Static program with fixed shapes taken from the inputs.
                    st_func2 = paddle.jit.to_static(
                        compute_dx_dv,
                        full_graph=True,
                        backend=None,
                    )
                    dx_2, dv_2 = st_func2(x_pd, indices_pd, value_pd, dout_pd)

                    err_msg = (
                        f"accumulate={accumulate}\nx_np:\n{x_np}\n"
                        f"indices_np:\n{indices_np}\nvalue_np:\n{value_np}\n"
                        f"out_np:{out_pd.numpy()}\n"
                    )
                    np.testing.assert_allclose(
                        dx_1.numpy(), dx_ref.numpy(), err_msg=err_msg
                    )
                    np.testing.assert_allclose(
                        dv_1.numpy(), dv_ref.numpy(), err_msg=err_msg
                    )
                    np.testing.assert_allclose(
                        dx_2.numpy(), dx_ref.numpy(), err_msg=err_msg
                    )
                    np.testing.assert_allclose(
                        dv_2.numpy(), dv_ref.numpy(), err_msg=err_msg
                    )
        finally:
            # Always restore the global prim switch for subsequent tests.
            paddle.framework.core._set_prim_all_enabled(False)
1228+
1229+
10891230
if __name__ == '__main__':
10901231
unittest.main()

0 commit comments

Comments
 (0)