Skip to content

Commit a4db267

Browse files
[slice] Add AMP Logic in ApplyGetitem (#74727)
1 parent b5992db commit a4db267

File tree

1 file changed

+76
-0
lines changed

paddle/fluid/pybind/slice_utils.h

Lines changed: 76 additions & 0 deletions

Original file line number | Diff line number | Diff line change
@@ -820,6 +820,43 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
820820
indices_int64.push_back(indice);
821821
}
822822

823+
// AMP Logic
824+
if (egr::Controller::Instance().GetAMPLevel() !=
825+
paddle::imperative::AmpLevel::O0) {
826+
auto op_name = phi::TransToFluidOpName("index_elementwise_get");
827+
paddle::small_vector<std::vector<paddle::Tensor>,
828+
egr::kSlotSmallVectorSize>
829+
amp_tensors_vector = {{self_tensor}};
830+
831+
auto amp_dst_dtype =
832+
paddle::imperative::GetAmpDestDtype(op_name, amp_tensors_vector);
833+
834+
auto new_self_tensor = paddle::imperative::AmpAutoCast(
835+
"self_tensor", self_tensor, amp_dst_dtype, op_name);
836+
auto new_tensor = paddle::imperative::AmpAutoCast(
837+
"tensor", tensor, amp_dst_dtype, op_name);
838+
839+
{
840+
paddle::imperative::AutoCastGuard guard(
841+
egr::Controller::Instance().GetCurrentAmpAttrs(),
842+
paddle::imperative::AmpLevel::O0);
843+
844+
AdvancedIndex ad = AdvancedIndex(new_tensor, indices_int64);
845+
const bool is_combined = false;
846+
const bool accumulate = false;
847+
848+
return index_elementwise_get_ad_func(new_self_tensor,
849+
ad.indices,
850+
ad.src_sizes,
851+
ad.src_strides,
852+
ad.indexed_sizes,
853+
ad.indexed_strides,
854+
slice_offset,
855+
accumulate,
856+
is_combined);
857+
}
858+
}
859+
823860
AdvancedIndex ad = AdvancedIndex(tensor, indices_int64);
824861
const bool is_combined = false;
825862
const bool accumulate = false;
@@ -1287,6 +1324,45 @@ static void ApplyGetitem(const int index_size,
12871324
transed_tensor,
12881325
&transed_index_int64);
12891326

1327+
// AMP Logic
1328+
if (egr::Controller::Instance().GetAMPLevel() !=
1329+
paddle::imperative::AmpLevel::O0) {
1330+
auto op_name = phi::TransToFluidOpName("index_elementwise_get");
1331+
paddle::small_vector<std::vector<paddle::Tensor>,
1332+
egr::kSlotSmallVectorSize>
1333+
amp_tensors_vector = {{*self_tensor}};
1334+
1335+
auto amp_dst_dtype =
1336+
paddle::imperative::GetAmpDestDtype(op_name, amp_tensors_vector);
1337+
1338+
auto new_self_tensor = paddle::imperative::AmpAutoCast(
1339+
"self_tensor", *self_tensor, amp_dst_dtype, op_name);
1340+
auto new_transed_tensor = paddle::imperative::AmpAutoCast(
1341+
"transed_tensor", *transed_tensor, amp_dst_dtype, op_name);
1342+
1343+
{
1344+
paddle::imperative::AutoCastGuard guard(
1345+
egr::Controller::Instance().GetCurrentAmpAttrs(),
1346+
paddle::imperative::AmpLevel::O0);
1347+
1348+
AdvancedIndex ad =
1349+
AdvancedIndex(new_transed_tensor, transed_index_int64);
1350+
1351+
const bool is_combined = (index_size == 1) ? false : true;
1352+
const bool accumulate = true;
1353+
*out = index_elementwise_get_ad_func(new_self_tensor,
1354+
ad.indices,
1355+
ad.src_sizes,
1356+
ad.src_strides,
1357+
ad.indexed_sizes,
1358+
ad.indexed_strides,
1359+
slice_offset,
1360+
accumulate,
1361+
is_combined);
1362+
}
1363+
return;
1364+
}
1365+
12901366
AdvancedIndex ad = AdvancedIndex(*transed_tensor, transed_index_int64);
12911367
// is_combined:
12921368
// Distinguishes between regular indexing (single index) and combined

0 commit comments

Comments (0)