Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/infrt/dialect/init_infrt_dialects.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "paddle/infrt/dialect/infrt/infrt_dialect.h"
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/pten/infrt_pten_kernel.h"
#include "paddle/infrt/dialect/pten/infrt_pten_tensor.h"
#include "paddle/infrt/dialect/pten/pten_base.h"
#include "paddle/infrt/dialect/tensor_shape.h"
Expand All @@ -34,6 +35,7 @@ void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT
mlir::pd::PaddleDialect,
#ifdef INFRT_WITH_PTEN
pten::PTENDenseTensorDialect,
pten::PTENKernelDialect,
pten::PTENDialect
#endif
>();
Expand Down
4 changes: 3 additions & 1 deletion paddle/infrt/dialect/pten/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ add_mlir_dialect(infrt_pten_kernel pten_kernel)

gather_srcs(infrt_src SRCS
pten_base.cc infrt_pten_tensor.cc
infrt_pten_tensor.cc)
infrt_pten_tensor.cc
infrt_pten_kernel.cc
)
38 changes: 38 additions & 0 deletions paddle/infrt/dialect/pten/infrt_pten_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/pten/infrt_pten_kernel.h"

#include <mlir/IR/BuiltinTypes.h>

#include "paddle/infrt/dialect/dense_tensor.h"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pten/infrt_pten_kernel.h.inc"
#include "paddle/infrt/dialect/pten/infrt_pten_kernelDialect.cpp.inc"

namespace infrt {
namespace pten {

// Registers every operation of the PTEN kernel dialect with MLIR. The op list
// is generated by TableGen: GET_OP_LIST makes the included .cpp.inc expand to
// the comma-separated op class names inside the addOperations<> argument list.
void PTENKernelDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/pten/infrt_pten_kernel.cpp.inc"
>();
}

} // namespace pten
} // namespace infrt

#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pten/infrt_pten_kernel.cpp.inc" // NOLINT
39 changes: 39 additions & 0 deletions paddle/infrt/dialect/pten/infrt_pten_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <mlir/Dialect/Traits.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>

#include "paddle/infrt/dialect/pten/infrt_pten_kernelTypes.h.inc"
#include "paddle/infrt/dialect/pten/pten_base.h"

#include "paddle/infrt/dialect/pten/infrt_pten_kernelDialect.h.inc"

namespace infrt {
namespace pten {} // namespace pten
} // namespace infrt
14 changes: 6 additions & 8 deletions paddle/infrt/dialect/pten/infrt_pten_kernel.td
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#ifndef PTEN_KERNEL
#ifdef PTEN_KERNEL
#else
#define PTEN_KERNEL

include "paddle/infrt/dialect/pten/infrt_pten_tensor.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
include "paddle/infrt/dialect/pten/infrt_pten_base.td"

def PTEN_KernelDialect : Dialect {
let name = "pten_kernel";
Expand All @@ -17,10 +21,4 @@ def PTEN_KernelDialect : Dialect {
class PDT_Kernel<string mnemonic, list<OpTrait> traits = []> : Op<PTEN_KernelDialect, mnemonic, !listconcat(traits, [IsolatedFromAbove])> {
}

def FakeKernelOp : PDT_Kernel<"pten.matmul.host.fp32"> {
let arguments = (ins CPU_Context:$dev_ctx, TensorType:$x, TensorType:$y, BoolAttr:$transpose_x, BoolAttr:$transpose_y);
let results = (outs TensorType:$output);
}

#endif

5 changes: 5 additions & 0 deletions paddle/infrt/dialect/pten/infrt_pten_tensor.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,9 @@ def PDT_FillDenseTensorOp_f32 : FillDenseTensorOp<F32ArrayAttr, "f32">;
def PDT_CreateAllocatorOp_cpu : CreateCPUAllocatorOp;
def PDT_CreateContextOp_cpu : CreateCPUContextOp;

// Placeholder kernel op ("fake_pten_kernel") in the PTEN tensor dialect.
// Inputs: a CPU device context, two tensors $x and $y, and a boolean
// $transpose_x attribute; produces a single output tensor.
// NOTE(review): the "Fake" name suggests this is a stand-in used for testing
// the dialect plumbing — confirm before relying on its semantics.
def FakeKernelOp : PDT_Op<"fake_pten_kernel"> {
let arguments = (ins CPU_Context:$dev_ctx, TensorType:$x, TensorType:$y, BoolAttr:$transpose_x);
let results = (outs TensorType:$output);
}

#endif
1 change: 1 addition & 0 deletions paddle/infrt/host_context/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ gather_srcs(infrt_src SRCS

cc_test_tiny(test_infrt_host_context_value SRCS value_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_kernel_utils SRCS kernel_utils_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_kernel_frame SRCS kernel_frame_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_kernel_registry SRCS kernel_registry_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_op_executable SRCS op_executable_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_core_runtime SRCS core_runtime_test.cc DEPS infrt ${MLIR_IR_LIBS})
Expand Down
31 changes: 31 additions & 0 deletions paddle/infrt/host_context/kernel_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/infrt/host_context/kernel_frame.h"
#include <sstream>

#include <memory>

Expand All @@ -25,5 +26,35 @@ std::ostream& operator<<(std::ostream& os, const KernelFrame& frame) {
return os;
}

#ifndef NDEBUG
std::string KernelFrame::DumpArgTypes() const {
std::stringstream ss;
for (auto* value : GetValues(0, GetNumElements())) {
if (value->is_type<bool>()) {
ss << "bool,";
} else if (value->is_type<tensor::DenseHostTensor>()) {
ss << "DenseHostTensor,";
} else if (value->is_type<float>()) {
ss << "float,";
} else if (value->is_type<float>()) {
ss << "int,";
} else if (value->is_type<pten::DenseTensor>()) {
ss << "pten::DenseTensor,";
} else if (value->is_type<pten::MetaTensor>()) {
ss << "pten::MetaTensor,";
} else if (value->is_type<::pten::CPUContext>()) {
ss << "pten::CPUContext,";
} else if (value->is_type<host_context::None>()) {
ss << "none,";
} else if (value->is_type<backends::CpuPtenContext>()) {
ss << "CpuPtenContext,";
} else {
ss << "unk,";
}
}
return ss.str();
}
#endif

} // namespace host_context
} // namespace infrt
74 changes: 44 additions & 30 deletions paddle/infrt/host_context/kernel_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,24 @@ namespace host_context {
class KernelFrame {
public:
int GetNumArgs() const { return num_arguments_; }
int GetNumResults() const { return num_results_ == -1 ? 0 : num_results_; }
int GetNumAttributes() const {
return value_or_attrs_.size() - num_arguments_ -
(num_results_ == -1 ? 0 : num_results_);
int GetNumResults() const {
return value_or_attrs_.size() - num_arguments_ - GetNumAttributes();
}
int GetNumAttributes() const { return num_attrs_ == -1 ? 0 : num_attrs_; }

//! Get something at a specific position \p index. The element might be an
//! argument, an attribute or a result.
template <typename T>
T& GetElementAt(int index) {
CHECK_LT(index, GetNumArgs() + GetNumAttributes() + GetNumResults());
CHECK_LT(static_cast<size_t>(index), GetNumElements());
return value_or_attrs_[index]->template get_or_default<T>();
}

Value* GetElementAt(int index) {
CHECK_LT(static_cast<size_t>(index), GetNumElements());
return value_or_attrs_[index];
}

// Get number of elements, either input, attributes or results.
size_t GetNumElements() const { return value_or_attrs_.size(); }

Expand All @@ -70,18 +74,21 @@ class KernelFrame {
}

Value* GetAttributeAt(int idx) {
CHECK_NE(num_results_, -1)
<< "Must call SetNumResults before GetAttributeAt";
CHECK_LT(idx,
static_cast<int>(value_or_attrs_.size() - num_arguments_ -
num_results_));
return value_or_attrs_[num_arguments_ + num_results_ + idx];
// CHECK_NE(num_results_, -1)
//<< "Must call SetNumResults before GetAttributeAt";
CHECK_LT(idx, GetNumAttributes());
return value_or_attrs_[num_arguments_ + idx];
}

void AddAttribute(Value* v) {
CHECK_NE(num_results_, -1)
<< "Must call SetNumResults before calling AddAttribute";
CHECK_EQ(num_results_, -1)
<< "Must call SetNumResults after calling AddAttribute";
value_or_attrs_.emplace_back(v);
if (num_attrs_ == -1) num_attrs_ = 0;
num_attrs_++;

CHECK_EQ(value_or_attrs_.size(),
static_cast<size_t>(num_arguments_ + num_attrs_));
}

template <typename T, typename... Args>
Expand All @@ -96,35 +103,41 @@ class KernelFrame {

template <typename T>
void SetResultAt(int index, T&& value) {
CHECK_LT(index, num_results_) << "Invalid result index";
CHECK(value_or_attrs_[num_arguments_ + index]);
value_or_attrs_[num_arguments_ + index]->set(std::move(value));
CHECK_LT(index, GetNumResults()) << "Invalid result index";
CHECK(value_or_attrs_[num_arguments_ + GetNumAttributes() + index]);
value_or_attrs_[num_arguments_ + GetNumAttributes() + index]->set(
std::move(value));
}

llvm::ArrayRef<Value*> GetResults() const {
return GetValues(num_arguments_, num_results_);
return GetValues(num_arguments_ + GetNumAttributes(), num_results_);
}
llvm::MutableArrayRef<Value*> GetResults() {
return GetMutableValues(num_arguments_, num_results_);
return GetMutableValues(num_arguments_ + GetNumAttributes(), num_results_);
}

llvm::ArrayRef<Value*> GetValues(size_t from, size_t length) const {
CHECK_LE(static_cast<int>(from + length), num_arguments_ + num_results_);
CHECK_LE(from + length, GetNumElements());
if (length == 0) return {};

return llvm::makeArrayRef(&value_or_attrs_[from], length);
}

llvm::MutableArrayRef<Value*> GetMutableValues(size_t from, size_t length) {
CHECK_LE(static_cast<int>(from + length), num_arguments_ + num_results_);
CHECK_LE(from + length, GetNumElements());
if (length == 0) return {};
return llvm::makeMutableArrayRef(&value_or_attrs_[from], length);
}

#ifndef NDEBUG
std::string DumpArgTypes() const;
#endif

bool IsEmpty() const { return value_or_attrs_.empty(); }

protected:
int num_arguments_{};
int num_attrs_{-1};
int num_results_{-1};

llvm::SmallVector<Value*, 8> value_or_attrs_;
Expand All @@ -136,44 +149,45 @@ class KernelFrameBuilder : public KernelFrame {
public:
void AddArgument(Value* value) {
CHECK(value);
CHECK_EQ(num_results_, -1)
<< "Should call AddArgument before calling SetNumResults";
CHECK_EQ(num_attrs_, -1)
<< "Should call AddArgument before calling SetAttributes";
value_or_attrs_.push_back(value);
++num_arguments_;
}

void SetResults(llvm::ArrayRef<Value*> values) {
CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size()));
CHECK_EQ(num_results_, -1);
CHECK_EQ(num_arguments_ + GetNumAttributes(),
static_cast<int>(value_or_attrs_.size()));
for (Value* x : values) {
value_or_attrs_.push_back(x);
}
num_results_ = values.size();
}

void SetNumResults(size_t n) {
CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size()));
CHECK_EQ(num_results_, -1);
num_results_ = n;
CHECK_EQ(num_arguments_ + GetNumAttributes(),
static_cast<int>(value_or_attrs_.size()));
for (size_t i = 0; i < n; i++) {
value_or_attrs_.emplace_back(new Value);
}
}

void SetResultAt(int result_id, Value* value) {
CHECK_EQ(static_cast<int>(value_or_attrs_.size()),
num_arguments_ + num_results_)
num_arguments_ + GetNumAttributes() + num_results_)
<< "Call SetNumResults first";
CHECK_LT(result_id + num_arguments_,
CHECK_LT(result_id + num_arguments_ + GetNumAttributes(),
static_cast<int>(value_or_attrs_.size()));
CHECK(value);
value_or_attrs_[num_arguments_ + result_id]->set(value);
value_or_attrs_[num_arguments_ + GetNumAttributes() + result_id]->set(
value);
}

void Reset() {
value_or_attrs_.clear();
num_arguments_ = 0;
num_results_ = -1;
num_attrs_ = -1;
}
};

Expand Down
Loading