|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
| 2 | +
|
| 3 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | + you may not use this file except in compliance with the License. |
| 5 | + You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | + Unless required by applicable law or agreed to in writing, software |
| 10 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | + See the License for the specific language governing permissions and |
| 13 | + limitations under the License. */ |
| 14 | + |
1 | 15 | #include "paddle/framework/op_registry.h" |
2 | 16 | #include <gtest/gtest.h> |
3 | 17 |
|
@@ -182,3 +196,71 @@ TEST(OperatorRegistrar, Test) { |
182 | 196 | using namespace paddle::framework; |
183 | 197 | OperatorRegistrar<CosineOpComplete, CosineOpProtoAndCheckerMaker> reg("cos"); |
184 | 198 | } |
| 199 | + |
| 200 | +namespace paddle { |
| 201 | +namespace framework { |
| 202 | + |
| 203 | +class OpKernelTestMaker : public OpProtoAndCheckerMaker { |
| 204 | + public: |
| 205 | + OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker) |
| 206 | + : OpProtoAndCheckerMaker(proto, op_checker) { |
| 207 | + AddComment("NoGradOp, same input output. no Grad"); |
| 208 | + } |
| 209 | +}; |
| 210 | + |
| 211 | +class OpWithKernelTest : public OperatorWithKernel { |
| 212 | + public: |
| 213 | + using OperatorWithKernel::OperatorWithKernel; |
| 214 | + |
| 215 | + protected: |
| 216 | + void InferShape(InferShapeContext* ctx) const override {} |
| 217 | + |
| 218 | + framework::OpKernelType GetActualKernelType( |
| 219 | + const framework::ExecutionContext& ctx) const override { |
| 220 | + return framework::OpKernelType(proto::DataType::FP32, ctx.device_context()); |
| 221 | + } |
| 222 | +}; |
| 223 | + |
| 224 | +template <typename DeviceContext, typename T> |
| 225 | +class OpKernelTest : public paddle::framework::OpKernel<T> { |
| 226 | + public: |
| 227 | + void Compute(const paddle::framework::ExecutionContext& ctx) const {} |
| 228 | +}; |
| 229 | + |
| 230 | +} // namespace framework |
| 231 | +} // namespace paddle |
| 232 | + |
| 233 | +REGISTER_OP_WITHOUT_GRADIENT(op_with_kernel, |
| 234 | + paddle::framework::OpWithKernelTest, |
| 235 | + paddle::framework::OpKernelTestMaker); |
| 236 | +REGISTER_OP_CPU_KERNEL( |
| 237 | + op_with_kernel, |
| 238 | + paddle::framework::OpKernelTest<paddle::platform::CPUDeviceContext, float>); |
| 239 | + |
| 240 | +REGISTER_OP_CUDA_KERNEL(op_with_kernel, |
| 241 | + paddle::framework::OpKernelTest< |
| 242 | + paddle::platform::CUDADeviceContext, float>); |
| 243 | + |
| 244 | +TEST(OperatorRegistrar, CPU) { |
| 245 | + paddle::framework::proto::OpDesc op_desc; |
| 246 | + paddle::platform::CPUPlace cpu_place; |
| 247 | + paddle::framework::Scope scope; |
| 248 | + |
| 249 | + op_desc.set_type("op_with_kernel"); |
| 250 | + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); |
| 251 | + |
| 252 | + op->Run(scope, cpu_place); |
| 253 | +} |
| 254 | + |
| 255 | +#ifdef PADDLE_WITH_CUDA |
| 256 | +TEST(OperatorRegistrar, CUDA) { |
| 257 | + paddle::framework::proto::OpDesc op_desc; |
| 258 | + paddle::platform::CUDAPlace cuda_place(0); |
| 259 | + paddle::framework::Scope scope; |
| 260 | + |
| 261 | + op_desc.set_type("op_with_kernel"); |
| 262 | + auto op = paddle::framework::OpRegistry::CreateOp(op_desc); |
| 263 | + |
| 264 | + op->Run(scope, cuda_place); |
| 265 | +} |
| 266 | +#endif |
0 commit comments