|
| 1 | +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "paddle/fluid/distributed/collective/MPITools.h" |
| 16 | +#include "paddle/fluid/distributed/collective/Common.h" |
| 17 | +#include "paddle/fluid/distributed/collective/Types.h" |
| 18 | + |
| 19 | +namespace paddle { |
| 20 | +namespace distributed { |
| 21 | +namespace mpi { |
| 22 | + |
| 23 | +MPI_Op ToMPIType(ReduceOp reduction) { |
| 24 | + static const std::map<ReduceOp, MPI_Op> red_type = { |
| 25 | + {ReduceOp::MIN, MPI_MIN}, |
| 26 | + {ReduceOp::MAX, MPI_MAX}, |
| 27 | + {ReduceOp::SUM, MPI_SUM}, |
| 28 | + {ReduceOp::PRODUCT, MPI_PROD}, |
| 29 | + }; |
| 30 | + auto it = red_type.find(reduction); |
| 31 | + PADDLE_ENFORCE_EQ(it != red_type.end(), |
| 32 | + true, |
| 33 | + platform::errors::InvalidArgument( |
| 34 | + "Invalid mpi reduction. Must be MPI_MIN | MPI_MAX | " |
| 35 | + "MPI_PROD | MPI_SUM.")); |
| 36 | + return it->second; |
| 37 | +} |
| 38 | + |
| 39 | +// NOTE: MPI dose not support CUDA aware now. |
| 40 | +bool CheckMpiCudaAware() { return false; } |
| 41 | + |
| 42 | +void CheckValidInputs(const std::vector<phi::DenseTensor>& tensors) { |
| 43 | + PADDLE_ENFORCE_EQ( |
| 44 | + tensors.size() == 1, |
| 45 | + true, |
| 46 | + platform::errors::InvalidArgument("the inputs size of MPI must be 1!")); |
| 47 | + |
| 48 | + PADDLE_ENFORCE_EQ(CheckTensorsInCudaPlace(tensors) && !CheckMpiCudaAware(), |
| 49 | + false, |
| 50 | + platform::errors::InvalidArgument( |
| 51 | + "Found CUDA Tensor. But CUDA-aware MPI not support!")); |
| 52 | +} |
| 53 | + |
| 54 | +} // namespace mpi |
| 55 | +} // namespace distributed |
| 56 | +} // namespace paddle |
0 commit comments