Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/ir/localize.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,20 @@ struct ChildLocalizer {
// Nothing to add.
return parent;
}
auto* block = getChildrenReplacement();
if (!hasUnreachableChild) {
block->list.push_back(parent);
block->finalize();
}
return block;
}

// Like `getReplacement`, but the result never contains the parent.
Block* getChildrenReplacement() {
auto* block = Builder(wasm).makeBlock();
block->list.set(sets);
if (hasUnreachableChild) {
// If there is an unreachable child then we do not need the parent at all,
// and we know the type is unreachable.
block->type = Type::unreachable;
} else {
// Otherwise, add the parent and finalize.
block->list.push_back(parent);
block->finalize();
}
return block;
}
Expand Down
209 changes: 151 additions & 58 deletions src/passes/Inlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "ir/element-utils.h"
#include "ir/find_all.h"
#include "ir/literal-utils.h"
#include "ir/localize.h"
#include "ir/module-utils.h"
#include "ir/names.h"
#include "ir/type-updating.h"
Expand Down Expand Up @@ -298,68 +299,119 @@ struct Updater : public PostWalker<Updater> {
Module* module;
std::map<Index, Index> localMapping;
Name returnName;
Type resultType;
bool isReturn;
Builder* builder;
PassOptions& options;

struct ReturnCallInfo {
// The original `return_call` or `return_call_indirect` or `return_call_ref`
// with its operands replaced with `local.get`s.
Expression* call;
// The branch that is serving as the "return" part of the original
// `return_call`.
Break* branch;
};

// Collect information on return_calls in the inlined body. Each will be
// turned into branches out of the original inlined body followed by
// non-return version of the original `return_call`, followed by a branch out
// to the caller. The branch labels will be filled in at the end of the walk.
std::vector<ReturnCallInfo> returnCallInfos;

Updater(PassOptions& options) : options(options) {}

void visitReturn(Return* curr) {
replaceCurrent(builder->makeBreak(returnName, curr->value));
}
// Return calls in inlined functions should only break out of the scope of
// the inlined code, not the entire function they are being inlined into. To
// achieve this, make the call a non-return call and add a break. This does
// not cause unbounded stack growth because inlining and return calling both
// avoid creating a new stack frame.
template<typename T> void handleReturnCall(T* curr, Type results) {
if (isReturn) {

template<typename T> void handleReturnCall(T* curr, Signature sig) {
if (isReturn || !curr->isReturn) {
// If the inlined callsite was already a return_call, then we can keep
// return_calls in the inlined function rather than downgrading them.
// That is, if A->B and B->C and both those calls are return_calls
// then after inlining A->B we want to now have A->C be a
// return_call.
return;
}

// Set the children to locals as necessary, then add a branch out of the
// inlined body. The branch label will be set later when we create branch
// targets for the calls.
Block* childBlock = ChildLocalizer(curr, getFunction(), *module, options)
.getChildrenReplacement();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should avoid adding temp vars when there is no need - is there a test that shows that? (I seem to see locals created all the time but I probably just missed it.)

Break* branch = builder->makeBreak(Name());
childBlock->list.push_back(branch);
childBlock->type = Type::unreachable;
replaceCurrent(childBlock);

curr->isReturn = false;
curr->type = results;
// There might still be unreachable children causing this to be unreachable.
curr->finalize();
if (results.isConcrete()) {
replaceCurrent(builder->makeBreak(returnName, curr));
} else {
replaceCurrent(builder->blockify(curr, builder->makeBreak(returnName)));
}
curr->type = sig.results;
returnCallInfos.push_back({curr, branch});
}

void visitCall(Call* curr) {
if (curr->isReturn) {
handleReturnCall(curr, module->getFunction(curr->target)->getResults());
}
handleReturnCall(curr, module->getFunction(curr->target)->getSig());
}

void visitCallIndirect(CallIndirect* curr) {
if (curr->isReturn) {
handleReturnCall(curr, curr->heapType.getSignature().results);
}
handleReturnCall(curr, curr->heapType.getSignature());
}

void visitCallRef(CallRef* curr) {
Type targetType = curr->target->type;
if (targetType.isNull()) {
// We don't know what type the call should return, but we can't leave it
// as a potentially-invalid return_call_ref, either.
replaceCurrent(getDroppedChildrenAndAppend(
curr, *module, options, Builder(*module).makeUnreachable()));
if (!targetType.isSignature()) {
// We don't know what type the call should return, but it will also never
// be reached, so we don't need to do anything here.
return;
}
if (curr->isReturn) {
handleReturnCall(curr, targetType.getHeapType().getSignature().results);
}
handleReturnCall(curr, targetType.getHeapType().getSignature());
}

void visitLocalGet(LocalGet* curr) {
curr->index = localMapping[curr->index];
}

void visitLocalSet(LocalSet* curr) {
curr->index = localMapping[curr->index];
}

void walk(Expression*& curr) {
PostWalker<Updater>::walk(curr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than these two lines, can just implement visitFunction.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh no, because it's not the entire function that gets walked, just the block containing the inlined body.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, thanks, that's what I was missing...

if (returnCallInfos.empty()) {
return;
}

Block* body = curr->cast<Block>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why must the function body be a block?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the block we create to hold the inlined contents in the parent. We always create a block in case there are any returns in the inlined contents that would need to be transformed into branches out of that block.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure you are passing the block? The inliner is crashing on this example.

(module
 (import "env" "h" (func $h (result i32)))
 (func $g (param $x i32) (result i32)
  (if (result i32) (local.get $x)
   (then
    (return_call $h)
   )
   (else
    (i32.const 17)
   )
  )
 )
 (func (export "f") (param $x i32) (result i32)
  (call $g (local.get $x))
 )
)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this comment was incorrect. I fixed it though, so the latest version of this branch should be good. I just tested this example locally and it worked fine.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, the fix is in #6451.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, good catch. I've moved it to the correct PR now.

auto blockNames = BranchUtils::BranchAccumulator::get(body);

for (Index i = 0; i < returnCallInfos.size(); ++i) {
auto& info = returnCallInfos[i];

// Add a block containing the previous body and a branch up to the caller.
// Give the block a name that will allow this return_call's original
// callsite to branch out of it then execute the call before returning to
// the caller.
auto name = Names::getValidName(
"__return_call", [&](Name test) { return !blockNames.count(test); }, i);
blockNames.insert(name);
info.branch->name = name;
Block* oldBody = builder->makeBlock(body->list, body->type);
body->list.clear();

if (resultType.isConcrete()) {
body->list.push_back(builder->makeBlock(
name, {builder->makeBreak(returnName, oldBody)}, Type::none));
} else {
oldBody->list.push_back(builder->makeBreak(returnName));
oldBody->name = name;
oldBody->type = Type::none;
body->list.push_back(oldBody);
}
body->list.push_back(info.call);
body->finalize(resultType);
}
}
};

// Core inlining logic. Modifies the outside function (adding locals as
Expand All @@ -376,15 +428,19 @@ static Expression* doInlining(Module* module,
Index nameHint = 0) {
Function* from = action.contents;
auto* call = (*action.callSite)->cast<Call>();

// Works for return_call, too
Type retType = module->getFunction(call->target)->getResults();

// Build the block that will contain the inlined contents.
Builder builder(*module);
auto* block = builder.makeBlock();
auto name = std::string("__inlined_func$") + from->name.toString();
if (nameHint) {
name += '$' + std::to_string(nameHint);
}
block->name = Name(name);

// In the unlikely event that the function already has a branch target with
// this name, fix that up, as otherwise we can get unexpected capture of our
// branches, that is, we could end up with this:
Expand All @@ -407,59 +463,96 @@ static Expression* doInlining(Module* module,
//
// (In this case we could use a second block and define the named block $X
// after the call's parameters, but that adds work for an extremely rare
// situation.)
// situation.) The latter case does not apply if the call is a return_call,
// because in that case the call's children do not appear inside the same
// block as the inlined body.
if (BranchUtils::hasBranchTarget(from->body, block->name) ||
BranchUtils::BranchSeeker::has(call, block->name)) {
(!call->isReturn && BranchUtils::BranchSeeker::has(call, block->name))) {
auto fromNames = BranchUtils::getBranchTargets(from->body);
auto callNames = BranchUtils::BranchAccumulator::get(call);
auto callNames = call->isReturn ? BranchUtils::NameSet{}
: BranchUtils::BranchAccumulator::get(call);
block->name = Names::getValidName(block->name, [&](Name test) {
return !fromNames.count(test) && !callNames.count(test);
});
}
if (call->isReturn) {
if (retType.isConcrete()) {
*action.callSite = builder.makeReturn(block);
} else {
*action.callSite = builder.makeSequence(block, builder.makeReturn());
}
} else {
*action.callSite = block;
}

// Prepare to update the inlined code's locals and other things.
Updater updater(options);
updater.setFunction(into);
updater.module = module;
updater.resultType = from->getResults();
updater.returnName = block->name;
updater.isReturn = call->isReturn;
updater.builder = &builder;
// Set up a locals mapping
for (Index i = 0; i < from->getNumLocals(); i++) {
updater.localMapping[i] = builder.addVar(into, from->getLocalType(i));
}
// Assign the operands into the params
for (Index i = 0; i < from->getParams().size(); i++) {
block->list.push_back(
builder.makeLocalSet(updater.localMapping[i], call->operands[i]));
}
// Zero out the vars (as we may be in a loop, and may depend on their
// zero-init value
for (Index i = 0; i < from->vars.size(); i++) {
auto type = from->vars[i];
if (!LiteralUtils::canMakeZero(type)) {
// Non-zeroable locals do not need to be zeroed out. As they have no zero
// value they by definition should not be used before being written to, so
// any value we set here would not be observed anyhow.
continue;

if (call->isReturn) {
// Wrap the existing function body in a block we can branch out of before
// entering the inlined function body. This block must have a name that is
// different from any other block name above the branch.
auto intoNames = BranchUtils::BranchAccumulator::get(into->body);
auto bodyName =
Names::getValidName(Name("__original_body"),
[&](Name test) { return !intoNames.count(test); });
if (retType.isConcrete()) {
into->body = builder.makeBlock(
bodyName, {builder.makeReturn(into->body)}, Type::none);
} else {
into->body = builder.makeBlock(
bodyName, {into->body, builder.makeReturn()}, Type::none);
}

// Sequence the inlined function body after the original caller body.
into->body = builder.makeSequence(into->body, block, retType);

// Replace the original callsite with an expression that assigns the
// operands into the params and branches out of the original body.
auto numParams = from->getParams().size();
if (numParams) {
auto* branchBlock = builder.makeBlock();
for (Index i = 0; i < numParams; i++) {
branchBlock->list.push_back(
builder.makeLocalSet(updater.localMapping[i], call->operands[i]));
}
branchBlock->list.push_back(builder.makeBreak(bodyName));
branchBlock->finalize(Type::unreachable);
*action.callSite = branchBlock;
} else {
*action.callSite = builder.makeBreak(bodyName);
}
} else {
// Assign the operands into the params
for (Index i = 0; i < from->getParams().size(); i++) {
block->list.push_back(
builder.makeLocalSet(updater.localMapping[i], call->operands[i]));
}
block->list.push_back(
builder.makeLocalSet(updater.localMapping[from->getVarIndexBase() + i],
LiteralUtils::makeZero(type, *module)));
// Zero out the vars (as we may be in a loop, and may depend on their
// zero-init value
for (Index i = 0; i < from->vars.size(); i++) {
auto type = from->vars[i];
if (!LiteralUtils::canMakeZero(type)) {
// Non-zeroable locals do not need to be zeroed out. As they have no
// zero value they by definition should not be used before being written
// to, so any value we set here would not be observed anyhow.
continue;
}
block->list.push_back(
builder.makeLocalSet(updater.localMapping[from->getVarIndexBase() + i],
LiteralUtils::makeZero(type, *module)));
}
*action.callSite = block;
}

// Generate and update the inlined contents
auto* contents = ExpressionManipulator::copy(from->body, *module);
debug::copyDebugInfo(from->body, contents, from, into);
updater.walk(contents);
block->list.push_back(contents);
block->type = retType;

// The ReFinalize below will handle propagating unreachability if we need to
// do so, that is, if the call was reachable but now the inlined content we
// replaced it with was unreachable. The opposite case requires special
Expand Down
19 changes: 16 additions & 3 deletions test/lit/passes/inlining-unreachable.wast
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,16 @@
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (block
;; CHECK-NEXT: (block $__inlined_func$callee
;; CHECK-NEXT: (call $imported
;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: (block
;; CHECK-NEXT: (block $__return_call
;; CHECK-NEXT: (block
;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: (br $__return_call)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: (call $imported
;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
Expand Down Expand Up @@ -114,7 +122,12 @@
;; CHECK-NEXT: (block $__inlined_func$0
;; CHECK-NEXT: (block
;; CHECK-NEXT: (nop)
;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: (block ;; (replaces unreachable CallRef we can't emit)
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (ref.null nofunc)
;; CHECK-NEXT: )
;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
Expand Down
7 changes: 6 additions & 1 deletion test/lit/passes/inlining_all-features.wast
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@
)
;; CHECK: (func $1 (type $none_=>_none)
;; CHECK-NEXT: (block $__inlined_func$0
;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: (block ;; (replaces unreachable CallRef we can't emit)
;; CHECK-NEXT: (drop
;; CHECK-NEXT: (ref.null nofunc)
;; CHECK-NEXT: )
;; CHECK-NEXT: (unreachable)
;; CHECK-NEXT: )
;; CHECK-NEXT: )
;; CHECK-NEXT: )
(func $1
Expand Down
Loading