|
16 | 16 |
|
17 | 17 | #include "paddle/fluid/framework/ir/graph_helper.h" |
18 | 18 | #include "paddle/fluid/framework/operator.h" |
19 | | -#include "paddle/phi/common/data_type.h" |
| 19 | +#include "paddle/phi/common/bfloat16.h" |
| 20 | +#include "paddle/phi/common/float16.h" |
| 21 | +#include "paddle/phi/common/place.h" |
| 22 | +#include "paddle/phi/core/dense_tensor.h" |
| 23 | +#include "paddle/phi/core/enforce.h" |
| 24 | +#include "paddle/phi/core/errors.h" |
20 | 25 |
|
21 | 26 | namespace paddle { |
22 | 27 | namespace framework { |
@@ -620,34 +625,45 @@ void FloatToHalfPass::ConvertWeightsData() const { |
620 | 625 | for (const auto& var_name : var_names) { |
621 | 626 | if (vars_convert_to_half_.count(var_name)) { |
622 | 627 | VLOG(4) << var_name << "'s data type was convert to half"; |
623 | | -#define CONVERT_TENSOR_DTYPE(DTYPE, dtype) \ |
624 | | - half_tensor.set_type(DTYPE); \ |
625 | | - auto* half_data = half_tensor.mutable_data<dtype>(platform::CPUPlace()); \ |
626 | | - for (int64_t i = 0; i < origin_tensor->numel(); i++) { \ |
627 | | - half_data[i] = static_cast<dtype>(origin_data[i]); \ |
628 | | - } \ |
629 | | - origin_tensor->clear(); \ |
630 | | - paddle::framework::TensorCopySync( \ |
631 | | - half_tensor, platform::CPUPlace(), origin_tensor) |
632 | 628 |
|
633 | 629 | auto* var = scope->FindLocalVar(var_name); |
634 | | - |
635 | | - if (var->IsType<phi::DenseTensor>()) { |
636 | | - auto* origin_tensor = var->GetMutable<phi::DenseTensor>(); |
637 | | - phi::DenseTensor half_tensor; |
638 | | - half_tensor.Resize(origin_tensor->dims()); |
639 | | - auto* origin_data = |
640 | | - origin_tensor->mutable_data<float>(platform::CPUPlace()); |
641 | | - if (half_precision_ == phi::DataType::FLOAT16) { |
642 | | - CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::FLOAT16, |
643 | | - phi::dtype::float16); |
644 | | - } else if (half_precision_ == phi::DataType::BFLOAT16) { |
645 | | - CONVERT_TENSOR_DTYPE(paddle::experimental::DataType::BFLOAT16, |
646 | | - phi::dtype::bfloat16); |
| 630 | + CHECK_EQ(var->IsType<phi::DenseTensor>(), true); |
| 631 | + |
| 632 | + auto* origin_tensor = var->GetMutable<phi::DenseTensor>(); |
| 633 | + |
| 634 | + phi::DenseTensor half_tensor; |
| 635 | + half_tensor.Resize(origin_tensor->dims()); |
| 636 | + half_tensor.set_type(half_precision_); |
| 637 | + |
| 638 | + if (half_precision_ == phi::DataType::FLOAT16) { |
| 639 | + auto* half_data = |
| 640 | + half_tensor.mutable_data<phi::dtype::float16>(phi::CPUPlace{}); |
| 641 | + for (int64_t i = 0; i < origin_tensor->numel(); i++) { |
| 642 | + if (origin_tensor->dtype() == phi::DataType::FLOAT64) { |
| 643 | + auto* origin_data = origin_tensor->data<double>(); |
| 644 | + half_data[i] = static_cast<phi::dtype::float16>(origin_data[i]); |
| 645 | + } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) { |
| 646 | + auto* origin_data = origin_tensor->data<float>(); |
| 647 | + half_data[i] = static_cast<phi::dtype::float16>(origin_data[i]); |
| 648 | + } |
| 649 | + } |
| 650 | + } else if (half_precision_ == phi::DataType::BFLOAT16) { |
| 651 | + auto* half_data = |
| 652 | + half_tensor.mutable_data<phi::dtype::bfloat16>(phi::CPUPlace{}); |
| 653 | + for (int64_t i = 0; i < origin_tensor->numel(); i++) { |
| 654 | + if (origin_tensor->dtype() == phi::DataType::FLOAT64) { |
| 655 | + auto* origin_data = origin_tensor->data<double>(); |
| 656 | + half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]); |
| 657 | + } else if (origin_tensor->dtype() == phi::DataType::FLOAT32) { |
| 658 | + auto* origin_data = origin_tensor->data<float>(); |
| 659 | + half_data[i] = static_cast<phi::dtype::bfloat16>(origin_data[i]); |
| 660 | + } |
647 | 661 | } |
648 | 662 | } |
| 663 | + origin_tensor->clear(); |
| 664 | + paddle::framework::TensorCopySync( |
| 665 | + half_tensor, phi::CPUPlace{}, origin_tensor); |
649 | 666 | } |
650 | | -#undef CONVERT_TENSOR_DTYPE |
651 | 667 | } |
652 | 668 | } |
653 | 669 |
|
|
0 commit comments