// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/api/paddle_infer_contrib.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle_infer {
namespace contrib {

using paddle::PaddleDType;

// Allocates page-locked (pinned) host memory, which enables faster
// host <=> device transfers. Returns nullptr when Paddle is built
// without CUDA support.
void* TensorUtils::CudaMallocPinnedMemory(size_t size) {
#if defined(PADDLE_WITH_CUDA)
  void* ptr = nullptr;
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaMallocHost(&ptr, size));
  return ptr;
#else
  return nullptr;
#endif
}

// Releases memory obtained from CudaMallocPinnedMemory; a no-op in
// CPU-only builds.
void TensorUtils::CudaFreePinnedMemory(void* ptr) {
#if defined(PADDLE_WITH_CUDA)
  PADDLE_ENFORCE_CUDA_SUCCESS(cudaFreeHost(ptr));
#endif
}
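
// Example (illustrative sketch, not from the original source): pairing the
// pinned-memory helpers for a host staging buffer. The buffer size and
// element type below are arbitrary.
//
//   float* host_buf = static_cast<float*>(
//       TensorUtils::CudaMallocPinnedMemory(1024 * sizeof(float)));
//   if (host_buf != nullptr) {  // nullptr in builds without PADDLE_WITH_CUDA
//     // ... fill host_buf and use it as the source of async copies ...
//     TensorUtils::CudaFreePinnedMemory(host_buf);
//   }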

void TensorUtils::CopyTensorImpl(Tensor* p_dst, const Tensor& src,
                                 void* exec_stream, CallbackFunc cb,
                                 void* cb_params) {
  Tensor& dst = *p_dst;
  dst.Reshape(src.shape());
  PADDLE_ENFORCE(
      src.place() == PlaceType::kCPU || src.place() == PlaceType::kGPU,
      paddle::platform::errors::InvalidArgument(
          "CopyTensor only supports PlaceType kCPU/kGPU now."));
  PADDLE_ENFORCE(
      dst.place() == PlaceType::kCPU || dst.place() == PlaceType::kGPU,
      paddle::platform::errors::InvalidArgument(
          "CopyTensor only supports PlaceType kCPU/kGPU now."));
  // Copy to CPU: GPU => CPU or CPU => CPU.
  if (dst.place() == PlaceType::kCPU) {
    switch (src.type()) {
      case PaddleDType::INT32:
        src.CopyToCpuImpl(dst.mutable_data<int32_t>(PlaceType::kCPU),
                          exec_stream, cb, cb_params);
        break;
      case PaddleDType::INT64:
        src.CopyToCpuImpl(dst.mutable_data<int64_t>(PlaceType::kCPU),
                          exec_stream, cb, cb_params);
        break;
      case PaddleDType::FLOAT32:
        src.CopyToCpuImpl(dst.mutable_data<float>(PlaceType::kCPU),
                          exec_stream, cb, cb_params);
        break;
      case PaddleDType::UINT8:
        src.CopyToCpuImpl(dst.mutable_data<uint8_t>(PlaceType::kCPU),
                          exec_stream, cb, cb_params);
        break;
      case PaddleDType::INT8:
        src.CopyToCpuImpl(dst.mutable_data<int8_t>(PlaceType::kCPU),
                          exec_stream, cb, cb_params);
        break;
      case PaddleDType::FLOAT16:
        src.CopyToCpuImpl(
            dst.mutable_data<paddle::platform::float16>(PlaceType::kCPU),
            exec_stream, cb, cb_params);
        break;
      default:
        PADDLE_THROW(paddle::platform::errors::Unimplemented(
            "Only INT32, INT64, UINT8, INT8, FLOAT16 and FLOAT32 are "
            "supported in Tensor. Other types are not implemented."));
    }
    // Copy to GPU: GPU => GPU or CPU => GPU.
  } else {
#if defined(PADDLE_WITH_CUDA)
    void* dst_data = nullptr;
    void* src_data = nullptr;
    size_t data_len = 0;
    int data_size = 0;
    PlaceType src_place;
    switch (src.type()) {
      case PaddleDType::INT32:
        dst_data =
            static_cast<void*>(dst.mutable_data<int32_t>(PlaceType::kGPU));
        src_data =
            static_cast<void*>(src.data<int32_t>(&src_place, &data_size));
        data_len = data_size * sizeof(int32_t);
        break;
      case PaddleDType::INT64:
        dst_data =
            static_cast<void*>(dst.mutable_data<int64_t>(PlaceType::kGPU));
        src_data =
            static_cast<void*>(src.data<int64_t>(&src_place, &data_size));
        data_len = data_size * sizeof(int64_t);
        break;
      case PaddleDType::FLOAT32:
        dst_data = static_cast<void*>(dst.mutable_data<float>(PlaceType::kGPU));
        src_data = static_cast<void*>(src.data<float>(&src_place, &data_size));
        data_len = data_size * sizeof(float);
        break;
      case PaddleDType::UINT8:
        dst_data =
            static_cast<void*>(dst.mutable_data<uint8_t>(PlaceType::kGPU));
        src_data =
            static_cast<void*>(src.data<uint8_t>(&src_place, &data_size));
        data_len = data_size * sizeof(uint8_t);
        break;
      case PaddleDType::INT8:
        dst_data =
            static_cast<void*>(dst.mutable_data<int8_t>(PlaceType::kGPU));
        src_data =
            static_cast<void*>(src.data<int8_t>(&src_place, &data_size));
        data_len = data_size * sizeof(int8_t);
        break;
      case PaddleDType::FLOAT16:
        dst_data = static_cast<void*>(
            dst.mutable_data<paddle::platform::float16>(PlaceType::kGPU));
        src_data = static_cast<void*>(
            src.data<paddle::platform::float16>(&src_place, &data_size));
        data_len = data_size * sizeof(paddle::platform::float16);
        break;
      default:
        PADDLE_THROW(paddle::platform::errors::Unimplemented(
            "Only INT32, INT64, UINT8, INT8, FLOAT16 and FLOAT32 are "
            "supported in Tensor. Other types are not implemented."));
    }

    paddle::platform::DeviceContextPool& pool =
        paddle::platform::DeviceContextPool::Instance();
    paddle::platform::CUDAPlace gpu_place(dst.device_);
    auto* dev_ctx = static_cast<const paddle::platform::CUDADeviceContext*>(
        pool.Get(gpu_place));

    if (src.place() == PlaceType::kCPU) {
      paddle::memory::Copy(gpu_place, dst_data, paddle::platform::CPUPlace(),
                           src_data, data_len, dev_ctx->stream());
    } else {
      // Use the source tensor's own device; a default-constructed CUDAPlace
      // would always read from device 0.
      paddle::memory::Copy(gpu_place, dst_data,
                           paddle::platform::CUDAPlace(src.device_), src_data,
                           data_len, dev_ctx->stream());
    }

    if (nullptr != exec_stream) {
      // Hand the transfer stream back to the caller, who synchronizes later.
      *(static_cast<cudaStream_t*>(exec_stream)) = dev_ctx->stream();
    } else if (cb) {
      // Fire the user-supplied host callback once the copy completes.
      cudaLaunchHostFunc(dev_ctx->stream(), cb, cb_params);
    } else {
      // Neither a stream nor a callback was given: block until done.
      cudaStreamSynchronize(dev_ctx->stream());
    }
#else
    PADDLE_THROW(paddle::platform::errors::Unavailable(
        "Cannot copy tensor to a GPU place because paddle is not compiled "
        "with CUDA."));
#endif
  }
}

void TensorUtils::CopyTensor(Tensor* p_dst, const Tensor& src) {
  CopyTensorImpl(p_dst, src, nullptr, nullptr, nullptr);
}
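
// Example (hedged sketch): a blocking copy between two predictor tensors.
// "predictor", "x" and "y" are placeholder names, not part of this file.
//
//   auto src_t = predictor->GetOutputHandle("y");
//   auto dst_t = predictor->GetInputHandle("x");
//   TensorUtils::CopyTensor(dst_t.get(), *src_t);  // returns once data landed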

void TensorUtils::CopyTensorAsync(Tensor* p_dst, const Tensor& src,
                                  void* exec_stream) {
  CopyTensorImpl(p_dst, src, exec_stream, nullptr, nullptr);
}
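
// Example (hedged sketch): an async copy; exec_stream acts as an output
// parameter that receives the cudaStream_t used for the transfer, so the
// caller decides when to synchronize.
//
//   cudaStream_t stream;
//   TensorUtils::CopyTensorAsync(dst_t.get(), *src_t, &stream);
//   // ... enqueue other work on the same stream ...
//   cudaStreamSynchronize(stream);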

void TensorUtils::CopyTensorAsync(Tensor* p_dst, const Tensor& src,
                                  CallbackFunc cb, void* cb_params) {
  CopyTensorImpl(p_dst, src, nullptr, cb, cb_params);
}
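
// Example (hedged sketch, assuming CallbackFunc is compatible with
// cudaHostFn_t, i.e. void(*)(void*), as implied by the cudaLaunchHostFunc
// call in CopyTensorImpl):
//
//   void OnCopyDone(void* user_data) { /* signal completion */ }
//   ...
//   TensorUtils::CopyTensorAsync(dst_t.get(), *src_t, OnCopyDone, nullptr);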

}  // namespace contrib
}  // namespace paddle_infer