@@ -22,22 +22,25 @@ const std::vector<int64_t>& get_dims_mapping(
2222 const phi::distributed::ArgDistAttr& dist_attr) {
2323 EXPECT_TRUE (
2424 paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr));
25- const auto & tensor_attr = paddle::get<0 >(dist_attr);
25+ const auto & tensor_attr =
26+ PADDLE_GET_CONST (phi::distributed::TensorDistAttr, dist_attr);
2627 return tensor_attr.dims_mapping ();
2728}
2829
2930bool is_partial (const phi::distributed::ArgDistAttr& dist_attr) {
3031 EXPECT_TRUE (
3132 paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr));
32- const auto & tensor_attr = paddle::get<0 >(dist_attr);
33+ const auto & tensor_attr =
34+ PADDLE_GET_CONST (phi::distributed::TensorDistAttr, dist_attr);
3335 return tensor_attr.is_partial ();
3436}
3537
3638const std::set<int64_t > get_partial_dims (
3739 const phi::distributed::ArgDistAttr& dist_attr) {
3840 EXPECT_TRUE (
3941 paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr));
40- const auto & tensor_attr = paddle::get<0 >(dist_attr);
42+ const auto & tensor_attr =
43+ PADDLE_GET_CONST (phi::distributed::TensorDistAttr, dist_attr);
4144 return tensor_attr.partial_dims ();
4245}
4346
@@ -74,7 +77,8 @@ void check_empty_dist_attr(const phi::distributed::ArgDistAttr& dist_attr,
7477 EXPECT_TRUE (
7578 paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr))
7679 << line;
77- EXPECT_EQ (paddle::get<0 >(dist_attr), phi::distributed::TensorDistAttr ());
80+ EXPECT_EQ (PADDLE_GET_CONST (phi::distributed::TensorDistAttr, dist_attr),
81+ phi::distributed::TensorDistAttr ());
7882}
7983
8084void check_partial_dims (const phi::distributed::ArgDistAttr& dist_attr,
@@ -89,23 +93,23 @@ void check_partial_dims(const phi::distributed::ArgDistAttr& dist_attr,
8993void clean_partial_status (phi::distributed::ArgDistAttr* dist_attr) {
9094 EXPECT_TRUE (
9195 paddle::holds_alternative<phi::distributed::TensorDistAttr>(*dist_attr));
92- auto & tensor_attr = paddle::get< 0 >( *dist_attr);
96+ auto & tensor_attr = PADDLE_GET (phi::distributed::TensorDistAttr, *dist_attr);
9397 tensor_attr.clean_partial_status ();
9498}
9599
96100void clean_partial_dims (phi::distributed::ArgDistAttr* dist_attr,
97101 std::vector<int64_t > dims) {
98102 EXPECT_TRUE (
99103 paddle::holds_alternative<phi::distributed::TensorDistAttr>(*dist_attr));
100- auto & tensor_attr = paddle::get< 0 >( *dist_attr);
104+ auto & tensor_attr = PADDLE_GET (phi::distributed::TensorDistAttr, *dist_attr);
101105 tensor_attr.clean_partial_dims (dims);
102106}
103107
104108void set_partial_status (phi::distributed::ArgDistAttr* dist_attr,
105109 std::vector<int64_t > dims) {
106110 EXPECT_TRUE (
107111 paddle::holds_alternative<phi::distributed::TensorDistAttr>(*dist_attr));
108- auto & tensor_attr = paddle::get< 0 >( *dist_attr);
112+ auto & tensor_attr = PADDLE_GET (phi::distributed::TensorDistAttr, *dist_attr);
109113 tensor_attr.set_partial_status (dims);
110114}
111115
0 commit comments