@@ -17,6 +17,7 @@ limitations under the License. */
1717#include < vector>
1818#include " gtest/gtest.h"
1919#include " paddle/fluid/framework/op_registry.h"
20+ #include " paddle/fluid/platform/bfloat16.h"
2021#include " paddle/fluid/platform/float16.h"
2122
2223USE_CPU_ONLY_OP (save_combine);
@@ -76,33 +77,34 @@ void CheckValues(T* expect, U* actual, const paddle::framework::LoD& expect_lod,
7677
7778// Here, we create 4 LoDTensors and use save_combine_op to first save these
7879// in a single file. Then, we use load_combine_op to load these sequentially
79- TEST (SaveLoadCombineOp, CPU) {
80+ template <typename T, typename U>
81+ void SaveLoadCombineOp () {
8082 paddle::framework::Scope scope;
8183 paddle::platform::CPUPlace place;
8284
8385 std::vector<int > lod1 = {0 , 1 , 2 , 3 , 10 };
8486 int numel1 = 100 ;
8587 paddle::framework::LoD expect_lod1;
86- int * expect1 = CreateForSaveCombineOp<int , int >(10 , 10 , lod1, " test_var1" ,
87- place, &scope, &expect_lod1);
88+ T * expect1 = CreateForSaveCombineOp<T, U >(10 , 10 , lod1, " test_var1" , place ,
89+ &scope, &expect_lod1);
8890
8991 std::vector<int > lod2 = {0 , 2 , 5 , 10 };
9092 int numel2 = 200 ;
9193 paddle::framework::LoD expect_lod2;
92- int * expect2 = CreateForSaveCombineOp<int , int >(10 , 20 , lod2, " test_var2" ,
93- place, &scope, &expect_lod2);
94+ T * expect2 = CreateForSaveCombineOp<T, U >(10 , 20 , lod2, " test_var2" , place ,
95+ &scope, &expect_lod2);
9496
9597 std::vector<int > lod3 = {0 , 2 , 3 , 20 };
9698 int numel3 = 4000 ;
9799 paddle::framework::LoD expect_lod3;
98- int * expect3 = CreateForSaveCombineOp<int , int >(20 , 200 , lod3, " test_var3" ,
99- place, &scope, &expect_lod3);
100+ T * expect3 = CreateForSaveCombineOp<T, U >(20 , 200 , lod3, " test_var3" , place ,
101+ &scope, &expect_lod3);
100102
101103 std::vector<int > lod4 = {0 , 1 , 20 };
102104 int numel4 = 1000 ;
103105 paddle::framework::LoD expect_lod4;
104- int * expect4 = CreateForSaveCombineOp<int , int >(20 , 50 , lod4, " test_var4" ,
105- place, &scope, &expect_lod4);
106+ T * expect4 = CreateForSaveCombineOp<T, U >(20 , 50 , lod4, " test_var4" , place ,
107+ &scope, &expect_lod4);
106108
107109 // Set attributes
108110 std::string filename = " check_tensor.ls" ;
@@ -128,15 +130,21 @@ TEST(SaveLoadCombineOp, CPU) {
128130 load_combine_op->Run (scope, place);
129131
130132 paddle::framework::LoD actual_lod1, actual_lod2, actual_lod3, actual_lod4;
131- int * actual1 = GetValuesAfterLoadCombineOp<int >(target1, scope, &actual_lod1);
132- int * actual2 = GetValuesAfterLoadCombineOp<int >(target2, scope, &actual_lod2);
133- int * actual3 = GetValuesAfterLoadCombineOp<int >(target3, scope, &actual_lod3);
134- int * actual4 = GetValuesAfterLoadCombineOp<int >(target4, scope, &actual_lod4);
135-
136- CheckValues<int , int >(expect1, actual1, expect_lod1, actual_lod1, numel1);
137- CheckValues<int , int >(expect2, actual2, expect_lod2, actual_lod2, numel2);
138- CheckValues<int , int >(expect3, actual3, expect_lod3, actual_lod3, numel3);
139- CheckValues<int , int >(expect4, actual4, expect_lod4, actual_lod4, numel4);
133+ U* actual1 = GetValuesAfterLoadCombineOp<U>(target1, scope, &actual_lod1);
134+ U* actual2 = GetValuesAfterLoadCombineOp<U>(target2, scope, &actual_lod2);
135+ U* actual3 = GetValuesAfterLoadCombineOp<U>(target3, scope, &actual_lod3);
136+ U* actual4 = GetValuesAfterLoadCombineOp<U>(target4, scope, &actual_lod4);
137+
138+ CheckValues<T, U>(expect1, actual1, expect_lod1, actual_lod1, numel1);
139+ CheckValues<T, U>(expect2, actual2, expect_lod2, actual_lod2, numel2);
140+ CheckValues<T, U>(expect3, actual3, expect_lod3, actual_lod3, numel3);
141+ CheckValues<T, U>(expect4, actual4, expect_lod4, actual_lod4, numel4);
142+ }
143+
144+ TEST (SaveLoadCombineOp, CPU) { SaveLoadCombineOp<int , int >(); }
145+
146+ TEST (SaveLoadCombineBF16Op, CPU) {
147+ SaveLoadCombineOp<paddle::platform::bfloat16, paddle::platform::bfloat16>();
140148}
141149
142150// FP16 version of SaveLoadCombineOp Test, only altering the saving aspect
0 commit comments