Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions paddle/pir/core/ir_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,52 @@ void IrPrinter::AddValueAlias(Value v, const std::string& alias) {
aliases_[key] = alias;
}

class CustomPrinter : public IrPrinter {
public:
explicit CustomPrinter(std::ostream& os, const PrintHooks& hooks)
: IrPrinter(os), hooks_(hooks) {}
void PrintType(Type type) override {
if (hooks_.type_print_hook) {
hooks_.type_print_hook(type, *this);
} else {
IrPrinter::PrintType(type);
}
}

void PrintAttribute(Attribute attr) override {
if (hooks_.attribute_print_hook) {
hooks_.attribute_print_hook(attr, *this);
} else {
IrPrinter::PrintAttribute(attr);
}
}

void PrintOperation(Operation* op) override {
if (hooks_.op_print_hook) {
hooks_.op_print_hook(op, *this);
} else {
IrPrinter::PrintOperation(op);
}
}

void PrintValue(Value v) override {
if (hooks_.value_print_hook) {
hooks_.value_print_hook(v, *this);
} else {
IrPrinter::PrintValue(v);
}
}

private:
const PrintHooks hooks_;
};

std::ostream& operator<<(std::ostream& os, const CustomPrintHelper& p) {
CustomPrinter printer(os, p.hooks_);
printer.PrintProgram(&p.prog_);
return os;
}

void Program::Print(std::ostream& os) const {
IrPrinter printer(os);
printer.PrintProgram(this);
Expand Down
41 changes: 35 additions & 6 deletions paddle/pir/core/ir_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class BasicIrPrinter {
public:
explicit BasicIrPrinter(std::ostream& os) : os(os) {}

void PrintType(Type type);
virtual void PrintType(Type type);

void PrintAttribute(Attribute attr);
virtual void PrintAttribute(Attribute attr);

public:
std::ostream& os;
Expand All @@ -49,7 +49,7 @@ class IR_API IrPrinter : public BasicIrPrinter {
void PrintProgram(const Program* program);

/// @brief dispatch to custom printer function or PrintGeneralOperation
void PrintOperation(Operation* op);
virtual void PrintOperation(Operation* op);
/// @brief print operation itself without its regions
void PrintOperationWithNoRegion(Operation* op);
/// @brief print operation and its regions
Expand All @@ -58,7 +58,7 @@ class IR_API IrPrinter : public BasicIrPrinter {
void PrintRegion(const Region& Region);
void PrintBlock(const Block& block);

void PrintValue(Value v);
virtual void PrintValue(Value v);

void PrintOpResult(Operation* op);

Expand All @@ -74,14 +74,43 @@ class IR_API IrPrinter : public BasicIrPrinter {

void AddIndentation();
void DecreaseIndentation();
std::string indentation() { return cur_indentation_; }
const std::string& indentation() const { return cur_indentation_; }

private:
size_t cur_result_number_{0};
size_t cur_block_argument_number_{0};
size_t cur_indentation_level_{0};
std::string cur_indentation_;
std::unordered_map<const void*, std::string> aliases_;
};

using ValuePrintHook =
std::function<void(Value value, IrPrinter& printer)>; // NOLINT
using TypePrintHook =
std::function<void(Type type, IrPrinter& printer)>; // NOLINT
using AttributePrintHook =
std::function<void(Attribute attr, IrPrinter& printer)>; // NOLINT
using OpPrintHook =
std::function<void(Operation* op, IrPrinter& printer)>; // NOLINT

struct IR_API PrintHooks {
ValuePrintHook value_print_hook{nullptr};
TypePrintHook type_print_hook{nullptr};
AttributePrintHook attribute_print_hook{nullptr};
OpPrintHook op_print_hook{nullptr};
};

class IR_API CustomPrintHelper {
public:
explicit CustomPrintHelper(const Program& program, const PrintHooks& hooks)
: hooks_(hooks), prog_(program) {}
friend IR_API std::ostream& operator<<(std::ostream& os,
const CustomPrintHelper& p);

private:
const PrintHooks& hooks_;
const Program& prog_;
};

IR_API std::ostream& operator<<(std::ostream& os, const CustomPrintHelper& p);

} // namespace pir
1 change: 1 addition & 0 deletions test/cpp/pir/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ paddle_test(ir_builder_test SRCS ir_builder_test.cc)
paddle_test(ir_program_test SRCS ir_program_test.cc)
paddle_test(ir_infershape_test SRCS ir_infershape_test.cc)
paddle_test(scalar_attribute_test SRCS scalar_attribute_test.cc)
paddle_test(ir_printer_test SRCS ir_printer_test.cc DEPS test_dialect)

file(
DOWNLOAD https://paddle-ci.gz.bcebos.com/ir_translator_test/resnet50_main.prog
Expand Down
85 changes: 85 additions & 0 deletions test/cpp/pir/core/ir_printer_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <sstream>

#include "paddle/pir/core/dialect.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/core/program.h"
#include "test/cpp/pir/tools/test_dialect.h"
#include "test/cpp/pir/tools/test_op.h"
#include "test/cpp/pir/tools/test_pir_utils.h"

TEST(printer_test, custom_hooks) {
pir::IrContext* ctx = pir::IrContext::Instance();
pir::Dialect* test_dialect = ctx->GetOrRegisterDialect<test::TestDialect>();
EXPECT_EQ(test_dialect != nullptr, true);

pir::OpInfo op1_info = ctx->GetRegisteredOpInfo(test::Operation1::name());
pir::OpInfo op2_info = ctx->GetRegisteredOpInfo(test::Operation2::name());

pir::Operation* op1 = pir::Operation::Create(
{},
test::CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"}),
{pir::Float32Type::get(ctx)},
op1_info);
pir::Operation* op2 = pir::Operation::Create(
{op1->result(0)}, {}, {pir::Float32Type::get(ctx)}, op2_info);

pir::Program program(ctx);
program.block()->push_back(op1);
program.block()->push_back(op2);

pir::PrintHooks hooks;
// this one retains old printing and adds new info
hooks.value_print_hook = [](pir::Value v, pir::IrPrinter& printer) {
printer.IrPrinter::PrintValue(v);
printer.os << " [extra info]";
};
// this one overrides old printing
hooks.op_print_hook = [](pir::Operation* op, pir::IrPrinter& printer) {
printer.PrintOpResult(op);
printer.os << " :=";

printer.os << " \"" << op->name() << "\"";
printer.PrintOpOperands(op);
printer.PrintAttributeMap(op);
printer.os << " :";
printer.PrintOpReturnType(op);
printer.os << "\n";
};

hooks.attribute_print_hook = [](pir::Attribute attr,
pir::IrPrinter& printer) {
printer.os << "[PlaceHolder]";
};
hooks.type_print_hook = [](pir::Type type, pir::IrPrinter& printer) {
printer.os << "[" << type << "]";
};

std::stringstream ss;

ss << pir::CustomPrintHelper{program, hooks};
EXPECT_EQ(
ss.str(),
"{\n"
"(%0 [extra info]) := \"test.operation1\" () "
"{op1_attr1:[PlaceHolder],op1_attr2:[PlaceHolder]} :[f32]\n"
"(%1 [extra info]) := \"test.operation2\" (%0 [extra info]) {} :[f32]\n"
"}\n");
}