Skip to content

Commit 42a0947

Browse files
authored
[Move selected_rows PR #5] VisitDataType use Pten::DataType (PaddlePaddle#39236)
* Added selected_rows and rw_lock to pten * Renamed the unit test target to fix CI * Removed Class SelectedRows in Fluid, changed include/cmake relationship, use pten::SelectedRows in Fluid * Remove rw_lock.h,rw_lock_test.cc in fluid * Use pten::RWLock and pten::AutoRDLock, fix CI * Use pten::SelectedRows * Use pten::SelectedRows * Fix to pass NPU CI * Selected_Rows inherits from TensorBase * Use pten::SelectedRows, to pass NPU CI * To fix NPU CI * To fix NPU CI again * Use paddle/pten/core/enforce and polish code * Use pten::DataType instead of using proto_type * Move part of data_type to pten * Polish Code
1 parent 452bcbe commit 42a0947

File tree

2 files changed

+72
-11
lines changed

2 files changed

+72
-11
lines changed

paddle/pten/core/selected_rows.cc

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/pten/core/selected_rows.h"
16-
17-
// See Note [ Why still include the fluid headers? ]
18-
#include "paddle/fluid/framework/data_type.h"
16+
#include "paddle/pten/core/utils/data_type.h"
1917

2018
namespace pten {
2119

@@ -191,16 +189,16 @@ void SelectedRows::Get(const pten::DenseTensor& ids,
191189
int64_t index = AutoGrownIndex(id, auto_grown, is_test);
192190
if (index < 0) {
193191
VLOG(5) << "id " << id << " not in the table, return 0";
194-
paddle::framework::VisitDataType(
195-
value_->type(),
192+
pten::VisitDataType(
193+
value_->dtype(),
196194
TensorFillVisitor(value, i * value_width, value_width, 0.0));
197195
} else {
198-
paddle::framework::VisitDataType(value_->type(),
199-
TensorCopyVisitor(value,
200-
i * value_width,
201-
*value_.get(),
202-
index * value_width,
203-
value_width));
196+
pten::VisitDataType(value_->dtype(),
197+
TensorCopyVisitor(value,
198+
i * value_width,
199+
*value_.get(),
200+
index * value_width,
201+
value_width));
204202
}
205203
}
206204
}

paddle/pten/core/utils/data_type.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <iostream>
17+
#include <string>
18+
#include <typeindex>
19+
20+
#include "paddle/pten/common/data_type.h"
21+
#include "paddle/pten/core/enforce.h"
22+
#include "paddle/pten/kernels/funcs/eigen/extensions.h"
23+
24+
namespace pten {
25+
26+
#define _PtenForEachDataTypeHelper_(callback, cpp_type, data_type) \
27+
callback(cpp_type, data_type);
28+
29+
#define _PtenForEachDataType_(callback) \
30+
_PtenForEachDataTypeHelper_(callback, float, DataType::FLOAT32); \
31+
_PtenForEachDataTypeHelper_( \
32+
callback, ::paddle::platform::float16, DataType::FLOAT16); \
33+
_PtenForEachDataTypeHelper_( \
34+
callback, ::paddle::platform::bfloat16, DataType::BFLOAT16); \
35+
_PtenForEachDataTypeHelper_(callback, double, DataType::FLOAT64); \
36+
_PtenForEachDataTypeHelper_(callback, int, DataType::INT32); \
37+
_PtenForEachDataTypeHelper_(callback, int64_t, DataType::INT64); \
38+
_PtenForEachDataTypeHelper_(callback, bool, DataType::BOOL); \
39+
_PtenForEachDataTypeHelper_(callback, uint8_t, DataType::UINT8); \
40+
_PtenForEachDataTypeHelper_(callback, int16_t, DataType::INT16); \
41+
_PtenForEachDataTypeHelper_(callback, int8_t, DataType::INT8); \
42+
_PtenForEachDataTypeHelper_( \
43+
callback, ::paddle::platform::complex<float>, DataType::COMPLEX64); \
44+
_PtenForEachDataTypeHelper_( \
45+
callback, ::paddle::platform::complex<double>, DataType::COMPLEX128);
46+
47+
template <typename Visitor>
48+
inline void VisitDataType(pten::DataType type, Visitor visitor) {
49+
#define PtenVisitDataTypeCallback(cpp_type, data_type) \
50+
do { \
51+
if (type == data_type) { \
52+
visitor.template apply<cpp_type>(); \
53+
return; \
54+
} \
55+
} while (0)
56+
57+
_PtenForEachDataType_(PtenVisitDataTypeCallback);
58+
#undef PtenVisitDataTypeCallback
59+
PADDLE_THROW(pten::errors::Unimplemented(
60+
"Not supported proto::VarType::Type(%d) as data type.",
61+
static_cast<int>(type)));
62+
}
63+
} // namespace pten

0 commit comments

Comments
 (0)