|
14 | 14 |
|
15 | 15 | #include "paddle/fluid/framework/ir/xpu/pass_utils.h" |
16 | 16 | #include "paddle/fluid/platform/enforce.h" |
| 17 | +#include "paddle/phi/kernels/cast_kernel.h" |
17 | 18 | #include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" |
18 | 19 |
|
19 | 20 | namespace paddle { |
@@ -123,11 +124,68 @@ template size_t HashTensor<int16_t>(const phi::DenseTensor& in); |
123 | 124 | template size_t HashTensor<float>(const phi::DenseTensor& in); |
124 | 125 | template size_t HashTensor<int8_t>(const phi::DenseTensor& in); |
125 | 126 |
|
| 127 | +template <> |
| 128 | +size_t HashTensor<float16>(const phi::DenseTensor& in) { |
| 129 | + phi::DenseTensor dst_tensor; |
| 130 | + auto* cpu_ctx = static_cast<phi::CPUContext*>( |
| 131 | + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); |
| 132 | + dst_tensor.Resize(in.dims()); |
| 133 | + dst_tensor.set_type(phi::DataType::FLOAT32); |
| 134 | + dst_tensor.set_layout(in.layout()); |
| 135 | + phi::CastKernel<float16>(*cpu_ctx, in, phi::DataType::FLOAT32, &dst_tensor); |
| 136 | + return HashTensor<float>(dst_tensor); |
| 137 | +} |
| 138 | + |
// Strip the trailing "_#<hash>" suffix from a node/parameter name, if any.
// Returns the name unchanged when no "_#" marker is present.
std::string GetPrefixWithoutHash(const std::string& name) {
  const std::size_t pos = name.find("_#");
  if (pos == std::string::npos) {
    return name;
  }
  return name.substr(0, pos);
}
130 | 143 |
|
| 144 | +void ConvertFromFp32ToFp16(phi::DenseTensor* weight, |
| 145 | + phi::DenseTensor* weight_max, |
| 146 | + bool transpose) { |
| 147 | + // Convert fp16 to fp32 |
| 148 | + phi::DenseTensor weight_fp32; |
| 149 | + CastToFp32(weight, &weight_fp32); |
| 150 | + |
| 151 | + if (transpose) { // (k, n) -> (n, k) |
| 152 | + Transpose2D(&weight_fp32); |
| 153 | + } |
| 154 | + |
| 155 | + auto FindMaxAbs = [](const float* data, int len) { |
| 156 | + float max_f = 0.0f; |
| 157 | + for (int i = 0; i < len; ++i) { |
| 158 | + float max = std::abs(data[i]); |
| 159 | + if (max > max_f) { |
| 160 | + max_f = max; |
| 161 | + } |
| 162 | + } |
| 163 | + return max_f; |
| 164 | + }; |
| 165 | + |
| 166 | + auto* cpu_ctx = static_cast<phi::CPUContext*>( |
| 167 | + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); |
| 168 | + // Convert to fp16 |
| 169 | + phi::DenseTensor weight_fp16; |
| 170 | + CastToFp16(&weight_fp32, &weight_fp16); |
| 171 | + // Find max |
| 172 | + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); |
| 173 | + int size = weight_fp32.numel(); |
| 174 | + float max_val = FindMaxAbs(weight_fp32.data<float>(), size); |
| 175 | + std::vector<float> max_vec(max_ptr_size, max_val); |
| 176 | + weight_max->set_type(phi::DataType::FLOAT32); |
| 177 | + weight_max->Resize({max_ptr_size}); |
| 178 | + memcpy(cpu_ctx->Alloc<float>(weight_max), |
| 179 | + max_vec.data(), |
| 180 | + max_ptr_size * sizeof(float)); |
| 181 | + weight->clear(); |
| 182 | + weight->set_type(phi::DataType::FLOAT16); |
| 183 | + weight->Resize({size}); |
| 184 | + memcpy(cpu_ctx->Alloc<float16>(weight), |
| 185 | + weight_fp16.data<float16>(), |
| 186 | + size * sizeof(float16)); |
| 187 | +} |
| 188 | + |
131 | 189 | template <typename Tcpu, typename Txpu> |
132 | 190 | void PrepareWeight(Graph* graph, |
133 | 191 | Scope* scope, |
@@ -268,6 +326,18 @@ template void PrepareWeight<float, float>( |
268 | 326 | const std::vector<float>& weight_scales, |
269 | 327 | bool per_channel_quant = false); |
270 | 328 |
|
| 329 | +template void PrepareWeight<float, float16>( |
| 330 | + Graph* graph, |
| 331 | + Scope* scope, |
| 332 | + BlockDesc* block, |
| 333 | + Node* weight, |
| 334 | + Node** dst_weight, |
| 335 | + Node** dst_weight_max, |
| 336 | + Node** dst_scale_max, |
| 337 | + bool transpose, |
| 338 | + const std::vector<float>& weight_scales, |
| 339 | + bool per_channel_quant = false); |
| 340 | + |
271 | 341 | template void PrepareWeight<float, int16_t>( |
272 | 342 | Graph* graph, |
273 | 343 | Scope* scope, |
|
0 commit comments