diff --git a/src/abi/stack.h b/src/abi/stack.h index 93a6e4cc1e4..6d7ead3d941 100644 --- a/src/abi/stack.h +++ b/src/abi/stack.h @@ -84,7 +84,7 @@ getStackSpace(Index local, Function* func, Index size, Module& wasm) { block->list.push_back(makeStackRestore()); block->list.push_back( builder.makeReturn(builder.makeLocalGet(temp, ret->value->type))); - block->finalize(); + block->finalize(&wasm); *ptr = block; } else { // restore, then return @@ -105,7 +105,7 @@ getStackSpace(Index local, Function* func, Index size, Module& wasm) { block->list.push_back(makeStackRestore()); block->list.push_back(builder.makeLocalGet(temp, func->getResults())); } - block->finalize(); + block->finalize(&wasm); func->body = block; } diff --git a/src/binaryen-c.cpp b/src/binaryen-c.cpp index 59815678250..77599dc2877 100644 --- a/src/binaryen-c.cpp +++ b/src/binaryen-c.cpp @@ -1093,7 +1093,7 @@ BinaryenExpressionRef BinaryenBlock(BinaryenModuleRef module, if (type != BinaryenTypeAuto()) { ret->finalize(Type(type)); } else { - ret->finalize(); + ret->finalize((Module*)module); } return static_cast(ret); } @@ -1985,8 +1985,9 @@ void BinaryenExpressionSetType(BinaryenExpressionRef expr, BinaryenType type) { void BinaryenExpressionPrint(BinaryenExpressionRef expr) { std::cout << *(Expression*)expr << '\n'; } -void BinaryenExpressionFinalize(BinaryenExpressionRef expr) { - ReFinalizeNode().visit((Expression*)expr); +void BinaryenExpressionFinalize(BinaryenExpressionRef expr, + BinaryenModuleRef module) { + ReFinalizeNode(*(Module*)module).visit((Expression*)expr); } BinaryenExpressionRef BinaryenExpressionCopy(BinaryenExpressionRef expr, diff --git a/src/binaryen-c.h b/src/binaryen-c.h index ee56afdadf8..05f50f4cb0c 100644 --- a/src/binaryen-c.h +++ b/src/binaryen-c.h @@ -1169,7 +1169,8 @@ BINARYEN_API void BinaryenExpressionSetType(BinaryenExpressionRef expr, // Prints text format of the given expression to stdout. BINARYEN_API void BinaryenExpressionPrint(BinaryenExpressionRef expr); // Re-finalizes an expression after it has been modified. -BINARYEN_API void BinaryenExpressionFinalize(BinaryenExpressionRef expr); +BINARYEN_API void BinaryenExpressionFinalize(BinaryenExpressionRef expr, + BinaryenModuleRef module); // Makes a deep copy of the given expression. BINARYEN_API BinaryenExpressionRef BinaryenExpressionCopy(BinaryenExpressionRef expr, BinaryenModuleRef module); diff --git a/src/cfg/Relooper.cpp b/src/cfg/Relooper.cpp index 2ef757bf144..63c32ce56dc 100644 --- a/src/cfg/Relooper.cpp +++ b/src/cfg/Relooper.cpp @@ -64,10 +64,11 @@ static wasm::Expression* HandleFollowupMultiples(wasm::Expression* Ret, } for (auto& [Id, Body] : Multiple->InnerMap) { Curr->name = Builder.getBlockBreakName(Id); - Curr->finalize(); // it may now be reachable, via a break + Curr->finalize( + &Builder.getModule()); // it may now be reachable, via a break auto* Outer = Builder.makeBlock(Curr); Outer->list.push_back(Body->Render(Builder, InLoop)); - Outer->finalize(); // TODO: not really necessary + Outer->finalize(&Builder.getModule()); // TODO: not really necessary Curr = Outer; } Parent->Next = Parent->Next->Next; @@ -91,15 +92,15 @@ static wasm::Expression* HandleFollowupMultiples(wasm::Expression* Ret, } else { for (auto* Entry : Loop->Entries) { Curr->name = Builder.getBlockBreakName(Entry->Id); - Curr->finalize(); + Curr->finalize(&Builder.getModule()); auto* Outer = Builder.makeBlock(Curr); - Outer->finalize(); // TODO: not really necessary + Outer->finalize(&Builder.getModule()); // TODO: not really necessary Curr = Outer; } } } } - Curr->finalize(); + Curr->finalize(&Builder.getModule()); return Curr; } @@ -132,7 +133,7 @@ Branch::Render(RelooperBuilder& Builder, Block* Target, bool SetLabel) { assert(Ancestor); Ret->list.push_back(Builder.makeShapeContinue(Ancestor->Id)); } - Ret->finalize(); + Ret->finalize(&Builder.getModule()); return Ret; } @@ -170,7 +171,7 @@ wasm::Expression* Block::Render(RelooperBuilder& Builder, bool InLoop) { } if (!ProcessedBranchesOut.size()) { - Ret->finalize(); + Ret->finalize(&Builder.getModule()); return Ret; } @@ -407,7 +408,7 @@ wasm::Expression* Block::Render(RelooperBuilder& Builder, bool InLoop) { if (Root) { Ret->list.push_back(Root); } - Ret->finalize(); + Ret->finalize(&Builder.getModule()); return Ret; } diff --git a/src/ir/branch-utils.h b/src/ir/branch-utils.h index 3527f1b3693..0231606885e 100644 --- a/src/ir/branch-utils.h +++ b/src/ir/branch-utils.h @@ -68,7 +68,9 @@ template void operateOnScopeNameUses(Expression* expr, T func) { // Similar to operateOnScopeNameUses, but also passes in the type that is sent // if the branch is taken. The type is none if there is no value. template -void operateOnScopeNameUsesAndSentTypes(Expression* expr, T func) { +void operateOnScopeNameUsesAndSentTypes(Module& wasm, + Expression* expr, + T func) { operateOnScopeNameUses(expr, [&](Name& name) { // There isn't a delegate mechanism for getting a sent value, so do a direct // if-else chain. This will need to be updated with new br variants. @@ -281,19 +283,20 @@ struct BranchSeeker Index found = 0; - std::unordered_set types; - BranchSeeker(Name target) : target(target) {} - void noteFound(Type newType) { - found++; - types.insert(newType); - } + void noteFound() { found++; } void visitExpression(Expression* curr) { - operateOnScopeNameUsesAndSentTypes(curr, [&](Name& name, Type type) { + if (curr->is() || curr->is()) { + // Keep the ignored types of users in sync with + // operateOnScopeNameUsesAndSentTypes and + // operateOnScopeNameUsesAndSentValues + return; + } + operateOnScopeNameUses(curr, [&](Name& name) { if (name == target) { - noteFound(type); + noteFound(); } }); } @@ -317,6 +320,33 @@ struct BranchSeeker } }; +// Like BranchSeeker, but accumulates the types sent to the block in question. +struct BranchTypeSeeker + : public PostWalker> { + Name target; + Module& wasm; + + Index found = 0; + + std::unordered_set types; + + BranchTypeSeeker(Module& wasm, Name target) : target(target), wasm(wasm) {} + + void noteFound(Type newType) { + found++; + types.insert(newType); + } + + void visitExpression(Expression* curr) { + operateOnScopeNameUsesAndSentTypes(wasm, curr, [&](Name& name, Type type) { + if (name == target) { + noteFound(type); + } + }); + } +}; + // Accumulates all the branches in an entire tree. struct BranchAccumulator : public PostWalker> { + + Module& wasm; + + TypeUpdater(Module& wasm) : wasm(wasm) {} + // Part 1: Scanning // track names to their blocks, so that when we remove a break to @@ -172,8 +177,9 @@ struct TypeUpdater // adds (or removes) breaks depending on break/switch contents void discoverBreaks(Expression* curr, int change) { BranchUtils::operateOnScopeNameUsesAndSentTypes( - curr, - [&](Name& name, Type type) { noteBreakChange(name, change, type); }); + wasm, curr, [&](Name& name, Type type) { + noteBreakChange(name, change, type); + }); // TODO: it may be faster to accumulate all changes to a set first, then // call noteBreakChange on the unique values, as a switch can be quite // large and have lots of repeated targets. diff --git a/src/ir/utils.h b/src/ir/utils.h index 72aa701bb77..20fc0dbc33b 100644 --- a/src/ir/utils.h +++ b/src/ir/utils.h @@ -158,8 +158,16 @@ struct ReFinalize // Re-finalize a single node. This is slow, if you want to refinalize // an entire ast, use ReFinalize struct ReFinalizeNode : public OverriddenVisitor { + + ReFinalizeNode(Module& wasm) : wasm(wasm) {} + + template void static finalize(Module& wasm, T* curr) { + curr->finalize(); + } + void static finalize(Module& wasm, Block* curr) { curr->finalize(&wasm); } + #define DELEGATE(CLASS_TO_VISIT) \ - void visit##CLASS_TO_VISIT(CLASS_TO_VISIT* curr) { curr->finalize(); } + void visit##CLASS_TO_VISIT(CLASS_TO_VISIT* curr) { finalize(wasm, curr); } #include "wasm-delegations.def" @@ -173,12 +181,15 @@ struct ReFinalizeNode : public OverriddenVisitor { void visitModule(Module* curr) { WASM_UNREACHABLE("unimp"); } // given a stack of nested expressions, update them all from child to parent - static void updateStack(ExpressionStack& expressionStack) { + static void updateStack(Module& wasm, ExpressionStack& expressionStack) { for (int i = int(expressionStack.size()) - 1; i >= 0; i--) { auto* curr = expressionStack[i]; - ReFinalizeNode().visit(curr); + ReFinalizeNode(wasm).visit(curr); } } + +private: + Module& wasm; }; // Adds drop() operations where necessary. This lets you not worry about adding @@ -209,7 +220,9 @@ struct AutoDrop : public WalkerPass> { return acted; } - void reFinalize() { ReFinalizeNode::updateStack(expressionStack); } + void reFinalize() { + ReFinalizeNode::updateStack(*getModule(), expressionStack); + } void visitBlock(Block* curr) { if (curr->list.size() == 0) { diff --git a/src/passes/AlignmentLowering.cpp b/src/passes/AlignmentLowering.cpp index d0ceeb6107b..377fc50b175 100644 --- a/src/passes/AlignmentLowering.cpp +++ b/src/passes/AlignmentLowering.cpp @@ -235,7 +235,7 @@ struct AlignmentLowering : public WalkerPass> { } else { WASM_UNREACHABLE("invalid size"); } - block->finalize(); + block->finalize(getModule()); return block; } diff --git a/src/passes/Asyncify.cpp b/src/passes/Asyncify.cpp index b30af085e9c..84375be0b06 100644 --- a/src/passes/Asyncify.cpp +++ b/src/passes/Asyncify.cpp @@ -935,7 +935,7 @@ struct AsyncifyFlow : public Pass { // something valid (which the optimizer can remove later). block->list.push_back(builder->makeUnreachable()); } - block->finalize(); + block->finalize(module); func->body = block; // Making things like returns conditional may alter types. ReFinalize().walkFunctionInModule(func, module); @@ -1062,7 +1062,7 @@ struct AsyncifyFlow : public Pass { for (auto j = begin; j <= i; j++) { block->list.push_back(list[j]); } - block->finalize(); + block->finalize(module); list[begin] = makeMaybeSkip(block); for (auto j = begin + 1; j <= i; j++) { list[j] = builder->makeNop(); @@ -1536,7 +1536,7 @@ struct AsyncifyLocals : public WalkerPass> { } block->list.push_back(builder->makeLocalSet(i, load)); } - block->finalize(); + block->finalize(getModule()); return block; } @@ -1578,7 +1578,7 @@ struct AsyncifyLocals : public WalkerPass> { } } block->list.push_back(builder->makeIncStackPos(offset)); - block->finalize(); + block->finalize(getModule()); return block; } @@ -1831,7 +1831,7 @@ struct Asyncify : public Pass { builder.makeBinary( Abstract::getBinary(pointerType, Abstract::GtU), stackPos, stackEnd), builder.makeUnreachable())); - body->finalize(); + body->finalize(module); auto func = builder.makeFunction( name, Signature(Type(params), Type::none), {}, body); module->addFunction(std::move(func)); diff --git a/src/passes/CodeFolding.cpp b/src/passes/CodeFolding.cpp index 0e57a79a21c..aac2e3cd627 100644 --- a/src/passes/CodeFolding.cpp +++ b/src/passes/CodeFolding.cpp @@ -484,7 +484,11 @@ struct CodeFolding : public WalkerPass> { auto oldType = curr->type; // NB: we template-specialize so that this calls the proper finalizer for // the type - curr->finalize(); + if (Block* currBlock = curr->template dynCast()) { + currBlock->finalize(getModule()); + } else { + curr->finalize(); + } // ensure the replacement has the same type, so the outside is not surprised block->finalize(oldType); replaceCurrent(block); @@ -726,7 +730,7 @@ struct CodeFolding : public WalkerPass> { // change) auto* toplevel = old->dynCast(); if (toplevel) { - toplevel->finalize(); + toplevel->finalize(getModule()); } if (old->type != Type::unreachable) { inner->list.push_back(builder.makeReturn(old)); @@ -735,7 +739,7 @@ struct CodeFolding : public WalkerPass> { } } } - inner->finalize(); + inner->finalize(getModule()); auto* outer = builder.makeBlock(); outer->list.push_back(inner); while (!mergeable.empty()) { diff --git a/src/passes/DeadCodeElimination.cpp b/src/passes/DeadCodeElimination.cpp index 0b382f46574..9609109f884 100644 --- a/src/passes/DeadCodeElimination.cpp +++ b/src/passes/DeadCodeElimination.cpp @@ -53,7 +53,16 @@ struct DeadCodeElimination } // as we remove code, we must keep the types of other nodes valid - TypeUpdater typeUpdater; + // lazily initialized because it requires access to the module + std::optional tu; + + TypeUpdater& typeUpdater() { + if (!tu) { + assert(getModule()); + tu.emplace(*getModule()); + } + return *tu; + } Expression* replaceCurrent(Expression* expression) { auto* old = getCurrent(); @@ -62,12 +71,12 @@ struct DeadCodeElimination } super::replaceCurrent(expression); // also update the type updater - typeUpdater.noteReplacement(old, expression); + typeUpdater().noteReplacement(old, expression); return expression; } void doWalkFunction(Function* func) { - typeUpdater.walk(func->body); + typeUpdater().walk(func->body); walk(func->body); } @@ -91,7 +100,7 @@ struct DeadCodeElimination bool afterUnreachable = false; for (auto* child : ChildIterator(curr)) { if (afterUnreachable) { - typeUpdater.noteRecursiveRemoval(child); + typeUpdater().noteRecursiveRemoval(child); continue; } if (child->type == Type::unreachable) { @@ -125,7 +134,7 @@ struct DeadCodeElimination } if (removeFromHere != 0) { for (Index i = removeFromHere; i < list.size(); i++) { - typeUpdater.noteRecursiveRemoval(list[i]); + typeUpdater().noteRecursiveRemoval(list[i]); } list.resize(removeFromHere); if (list.size() == 1 && list[0]->is()) { @@ -138,14 +147,14 @@ struct DeadCodeElimination // have a concrete value flowing out) then remove it, which may allow // more reduction. if (block->type.isConcrete() && list.back()->type == Type::unreachable && - !typeUpdater.hasBreaks(block)) { - typeUpdater.changeType(block, Type::unreachable); + !typeUpdater().hasBreaks(block)) { + typeUpdater().changeType(block, Type::unreachable); } } else if (auto* iff = curr->dynCast()) { if (iff->condition->type == Type::unreachable) { - typeUpdater.noteRecursiveRemoval(iff->ifTrue); + typeUpdater().noteRecursiveRemoval(iff->ifTrue); if (iff->ifFalse) { - typeUpdater.noteRecursiveRemoval(iff->ifFalse); + typeUpdater().noteRecursiveRemoval(iff->ifFalse); } replaceCurrent(iff->condition); return; @@ -155,7 +164,7 @@ struct DeadCodeElimination if (iff->type != Type::unreachable && iff->ifFalse && iff->ifTrue->type == Type::unreachable && iff->ifFalse->type == Type::unreachable) { - typeUpdater.changeType(iff, Type::unreachable); + typeUpdater().changeType(iff, Type::unreachable); } } else if (auto* loop = curr->dynCast()) { // The loop body may have unreachable type if it branches back to the @@ -173,7 +182,7 @@ struct DeadCodeElimination } if (tryy->type != Type::unreachable && tryy->body->type == Type::unreachable && allCatchesUnreachable) { - typeUpdater.changeType(tryy, Type::unreachable); + typeUpdater().changeType(tryy, Type::unreachable); } } else { WASM_UNREACHABLE("unimplemented DCE control flow structure"); diff --git a/src/passes/Flatten.cpp b/src/passes/Flatten.cpp index 3c0aa78a0ab..719049b3a28 100644 --- a/src/passes/Flatten.cpp +++ b/src/passes/Flatten.cpp @@ -169,7 +169,7 @@ struct Flatten } iff->finalize(); if (prelude) { - ReFinalizeNode().visit(prelude); + ReFinalizeNode(*getModule()).visit(prelude); ourPreludes.push_back(prelude); } replaceCurrent(rep); @@ -332,7 +332,7 @@ struct Flatten // continue for general handling of everything, control flow or otherwise curr = getCurrent(); // we may have replaced it // we have changed children - ReFinalizeNode().visit(curr); + ReFinalizeNode(*getModule()).visit(curr); if (curr->type == Type::unreachable) { ourPreludes.push_back(curr); replaceCurrent(builder.makeUnreachable()); @@ -393,7 +393,7 @@ struct Flatten auto* ret = Builder(*getModule()).makeBlock(thePreludes); thePreludes.clear(); ret->list.push_back(after); - ret->finalize(); + ret->finalize(getModule()); return ret; } diff --git a/src/passes/JSPI.cpp b/src/passes/JSPI.cpp index 1adf0de4ca2..df90025070a 100644 --- a/src/passes/JSPI.cpp +++ b/src/passes/JSPI.cpp @@ -212,7 +212,7 @@ struct JSPI : public Pass { resultsType = Type::i32; block->list.push_back(builder.makeConst(0)); } - block->finalize(); + block->finalize(module); auto wrapperFunc = Builder::makeFunction(wrapperName, std::move(namedWrapperParams), @@ -273,7 +273,7 @@ struct JSPI : public Pass { block->list.push_back( builder.makeLocalGet(*returnIndex, stub->getResults())); } - block->finalize(); + block->finalize(module); call->type = im->getResults(); stub->body = block; wrapperIm->type = Signature(Type(params), call->type); diff --git a/src/passes/LegalizeJSInterface.cpp b/src/passes/LegalizeJSInterface.cpp index 3ca6b7095e0..1e07c1711ab 100644 --- a/src/passes/LegalizeJSInterface.cpp +++ b/src/passes/LegalizeJSInterface.cpp @@ -278,7 +278,7 @@ struct LegalizeJSInterface : public Pass { {I64Utilities::getI64High(builder, index)}, Type::none)); block->list.push_back(I64Utilities::getI64Low(builder, index)); - block->finalize(); + block->finalize(module); legal->body = block; } else { legal->body = call; diff --git a/src/passes/MergeBlocks.cpp b/src/passes/MergeBlocks.cpp index 9a0b5d27be5..765e119a2d6 100644 --- a/src/passes/MergeBlocks.cpp +++ b/src/passes/MergeBlocks.cpp @@ -254,7 +254,7 @@ static void optimizeBlock(Block* curr, drop->finalize(); childBlock->list.back() = drop; } - childBlock->finalize(); + childBlock->finalize(module); child = list[i] = childBlock; more = true; changed = true; @@ -362,7 +362,7 @@ static void optimizeBlock(Block* curr, // Update the child. childList.swap(filtered); // We may have removed unreachable items. - childBlock->finalize(); + childBlock->finalize(module); if (loop) { loop->finalize(); } diff --git a/src/passes/Poppify.cpp b/src/passes/Poppify.cpp index a7f3e5cccd7..6d3d6edf177 100644 --- a/src/passes/Poppify.cpp +++ b/src/passes/Poppify.cpp @@ -144,7 +144,8 @@ struct Poppifier : BinaryenIRWriter { }; Poppifier::Poppifier(Function* func, Module* module) - : BinaryenIRWriter(func), module(module), builder(*module) { + : BinaryenIRWriter(*module, func), module(module), + builder(*module) { // Start with a scope to emit top-level instructions into scopeStack.emplace_back(Scope::Func); diff --git a/src/passes/ReReloop.cpp b/src/passes/ReReloop.cpp index d195846f6ce..07545229255 100644 --- a/src/passes/ReReloop.cpp +++ b/src/passes/ReReloop.cpp @@ -53,7 +53,7 @@ struct ReReloop final : public Pass { CFG::Block* setCurrCFGBlock(CFG::Block* curr) { if (currCFGBlock) { - finishBlock(); + finishBlock(*curr->relooper->Module); } return currCFGBlock = curr; } @@ -64,7 +64,7 @@ struct ReReloop final : public Pass { Block* getCurrBlock() { return currCFGBlock->Code->cast(); } - void finishBlock() { getCurrBlock()->finalize(); } + void finishBlock(Module& wasm) { getCurrBlock()->finalize(&wasm); } // break handling @@ -310,7 +310,7 @@ struct ReReloop final : public Pass { curr->run(); } // finish the current block - finishBlock(); + finishBlock(*module); // blocks that do not have any exits are dead ends in the relooper. we need // to make sure that are in fact dead ends, and do not flow control // anywhere. add a return as needed @@ -320,7 +320,7 @@ struct ReReloop final : public Pass { block->list.push_back(function->getResults() == Type::none ? (Expression*)builder->makeReturn() : (Expression*)builder->makeUnreachable()); - block->finalize(); + block->finalize(module); } } #ifdef RERELOOP_DEBUG diff --git a/src/passes/RemoveUnusedBrs.cpp b/src/passes/RemoveUnusedBrs.cpp index 576f2fe9129..7080473e359 100644 --- a/src/passes/RemoveUnusedBrs.cpp +++ b/src/passes/RemoveUnusedBrs.cpp @@ -46,7 +46,7 @@ stealSlice(Builder& builder, Block* input, Index from, Index to) { for (Index i = from; i < to; i++) { block->list.push_back(input->list[i]); } - block->finalize(); + block->finalize(&builder.getModule()); ret = block; } if (to == input->list.size()) { @@ -537,7 +537,7 @@ struct RemoveUnusedBrs : public WalkerPass> { if (iff->ifTrue->type == Type::unreachable) { iff->ifFalse = stealSlice(builder, block, i + 1, list.size()); iff->finalize(); - block->finalize(); + block->finalize(getModule()); return true; } } else { @@ -574,7 +574,7 @@ struct RemoveUnusedBrs : public WalkerPass> { block->list.push_back(item); } } - block->finalize(); + block->finalize(getModule()); return block; }; @@ -582,13 +582,13 @@ struct RemoveUnusedBrs : public WalkerPass> { iff->ifFalse = blockifyMerge( iff->ifFalse, stealSlice(builder, block, i + 1, list.size())); iff->finalize(); - block->finalize(); + block->finalize(getModule()); return true; } else if (iff->ifFalse->type == Type::unreachable) { iff->ifTrue = blockifyMerge( iff->ifTrue, stealSlice(builder, block, i + 1, list.size())); iff->finalize(); - block->finalize(); + block->finalize(getModule()); return true; } } @@ -622,7 +622,7 @@ struct RemoveUnusedBrs : public WalkerPass> { builder.makeIf(brIf->condition, builder.makeBreak(brIf->name), stealSlice(builder, block, i + 1, list.size())); - block->finalize(); + block->finalize(getModule()); return true; } } @@ -643,8 +643,11 @@ struct RemoveUnusedBrs : public WalkerPass> { bool sinkBlocks(Function* func) { struct Sinker : public PostWalker { + Module& wasm; bool worked = false; + Sinker(Module& wasm) : wasm(wasm) {} + void visitBlock(Block* curr) { // If the block has a single child which is a loop, and the block is // named, then it is the exit for the loop. It's better to move it into @@ -695,7 +698,7 @@ struct RemoveUnusedBrs : public WalkerPass> { // The block used to contain the if, and may have changed type // from unreachable to none, for example, if the if has an // unreachable condition but the arm is not unreachable. - curr->finalize(); + curr->finalize(&wasm); iff->finalize(); replaceCurrent(iff); worked = true; @@ -707,7 +710,9 @@ struct RemoveUnusedBrs : public WalkerPass> { } } } - } sinker; + }; + + Sinker sinker(*getModule()); sinker.doWalkFunction(func); if (sinker.worked) { diff --git a/src/passes/SimplifyLocals.cpp b/src/passes/SimplifyLocals.cpp index 28b13980b2a..c69a2e08480 100644 --- a/src/passes/SimplifyLocals.cpp +++ b/src/passes/SimplifyLocals.cpp @@ -490,7 +490,7 @@ struct SimplifyLocals auto* set = (*item)->template cast(); block->list[block->list.size() - 1] = set->value; *item = builder.makeNop(); - block->finalize(); + block->finalize(this->getModule()); assert(block->type != Type::none); loop->finalize(); set->value = loop; @@ -626,7 +626,7 @@ struct SimplifyLocals this->replaceCurrent(newLocalSet); sinkables.clear(); anotherCycle = true; - block->finalize(); + block->finalize(this->getModule()); } // optimize local.sets from both sides of an if into a return value @@ -710,7 +710,7 @@ struct SimplifyLocals ifTrueBlock->list[ifTrueBlock->list.size() - 1] = (*ifTrueItem)->template cast()->value; ExpressionManipulator::nop(*ifTrueItem); - ifTrueBlock->finalize(); + ifTrueBlock->finalize(this->getModule()); assert(ifTrueBlock->type != Type::none); } if (iff->ifFalse->type != Type::unreachable) { @@ -718,7 +718,7 @@ struct SimplifyLocals ifFalseBlock->list[ifFalseBlock->list.size() - 1] = (*ifFalseItem)->template cast()->value; ExpressionManipulator::nop(*ifFalseItem); - ifFalseBlock->finalize(); + ifFalseBlock->finalize(this->getModule()); assert(ifFalseBlock->type != Type::none); } iff->finalize(); // update type @@ -822,7 +822,7 @@ struct SimplifyLocals auto* set = (*item)->template cast(); ifTrueBlock->list[ifTrueBlock->list.size() - 1] = set->value; *item = builder.makeNop(); - ifTrueBlock->finalize(); + ifTrueBlock->finalize(this->getModule()); assert(ifTrueBlock->type != Type::none); // Update the ifFalse side. iff->ifFalse = builder.makeLocalGet(set->index, localType); diff --git a/src/passes/SpillPointers.cpp b/src/passes/SpillPointers.cpp index 3af7039cdd3..316db2f0d7f 100644 --- a/src/passes/SpillPointers.cpp +++ b/src/passes/SpillPointers.cpp @@ -170,7 +170,7 @@ struct SpillPointers auto temp = builder.addVar(func, operand->type); auto* set = builder.makeLocalSet(temp, operand); block->list.push_back(set); - block->finalize(); + block->finalize(module); if (actualPointers.count(&operand) > 0) { // this is something we track, and it's moving - update actualPointers[&operand] = &set->value; @@ -202,7 +202,7 @@ struct SpillPointers } // add the (modified) call block->list.push_back(call); - block->finalize(); + block->finalize(module); *origin = block; } }; diff --git a/src/passes/Vacuum.cpp b/src/passes/Vacuum.cpp index 41a8e02da8a..489f37bce5a 100644 --- a/src/passes/Vacuum.cpp +++ b/src/passes/Vacuum.cpp @@ -367,7 +367,7 @@ struct Vacuum : public WalkerPass> { // we may be able to remove this, if there are no brs bool canPop = true; if (block->name.is()) { - BranchUtils::BranchSeeker seeker(block->name); + BranchUtils::BranchTypeSeeker seeker(*getModule(), block->name); Expression* temp = block; seeker.walk(temp); if (seeker.found && Type::hasLeastUpperBound(seeker.types)) { diff --git a/src/tools/fuzzing/fuzzing.cpp b/src/tools/fuzzing/fuzzing.cpp index d9e0f6baf77..a9b92a63a3e 100644 --- a/src/tools/fuzzing/fuzzing.cpp +++ b/src/tools/fuzzing/fuzzing.cpp @@ -1351,7 +1351,7 @@ Expression* TranslateToFuzzReader::makeBlock(Type type) { if (type.isConcrete()) { ret->finalize(type); } else { - ret->finalize(); + ret->finalize(&builder.getModule()); } if (ret->type != type) { // e.g. we might want an unreachable block, but a child breaks to it diff --git a/src/wasm-builder.h b/src/wasm-builder.h index 9e2881145b9..0d7b394a97e 100644 --- a/src/wasm-builder.h +++ b/src/wasm-builder.h @@ -40,6 +40,8 @@ class Builder { public: Builder(Module& wasm) : wasm(wasm) {} + Module& getModule() { return wasm; } + // make* functions create an expression instance. static std::unique_ptr makeFunction(Name name, @@ -172,14 +174,14 @@ class Builder { auto* ret = wasm.allocator.alloc(); if (first) { ret->list.push_back(first); - ret->finalize(); + ret->finalize(&wasm); } return ret; } Block* makeBlock(Name name, Expression* first = nullptr) { auto* ret = makeBlock(first); ret->name = name; - ret->finalize(); + ret->finalize(&wasm); return ret; } @@ -192,7 +194,7 @@ class Builder { Block* makeBlock(const T& items) { auto* ret = wasm.allocator.alloc(); ret->list.set(items); - ret->finalize(); + ret->finalize(&wasm); return ret; } @@ -1316,7 +1318,7 @@ class Builder { } if (append) { block->list.push_back(append); - block->finalize(); + block->finalize(&wasm); } return block; } @@ -1341,7 +1343,7 @@ class Builder { block->name = name; if (append) { block->list.push_back(append); - block->finalize(); + block->finalize(&wasm); } return block; } @@ -1351,7 +1353,7 @@ class Builder { Block* makeSequence(Expression* left, Expression* right) { auto* block = makeBlock(left); block->list.push_back(right); - block->finalize(); + block->finalize(&wasm); return block; } diff --git a/src/wasm-stack.h b/src/wasm-stack.h index 1f66212ad81..f4abd9f5ac8 100644 --- a/src/wasm-stack.h +++ b/src/wasm-stack.h @@ -159,7 +159,7 @@ class BinaryInstWriter : public OverriddenVisitor { template class BinaryenIRWriter : public Visitor> { public: - BinaryenIRWriter(Function* func) : func(func) {} + BinaryenIRWriter(Module& wasm, Function* func) : wasm(wasm), func(func) {} void write(); @@ -172,6 +172,7 @@ class BinaryenIRWriter : public Visitor> { void visitTry(Try* curr); protected: + Module& wasm; Function* func = nullptr; private: @@ -413,8 +414,9 @@ class BinaryenIRToBinaryWriter Function* func = nullptr, bool sourceMap = false, bool DWARF = false) - : BinaryenIRWriter(func), parent(parent), - writer(parent, o, func, sourceMap, DWARF), sourceMap(sourceMap) {} + : BinaryenIRWriter(*parent.getModule(), func), + parent(parent), writer(parent, o, func, sourceMap, DWARF), + sourceMap(sourceMap) {} void emit(Expression* curr) { writer.visit(curr); } void emitHeader() { @@ -454,7 +456,7 @@ class BinaryenIRToBinaryWriter class StackIRGenerator : public BinaryenIRWriter { public: StackIRGenerator(Module& module, Function* func) - : BinaryenIRWriter(func), module(module) {} + : BinaryenIRWriter(module, func), module(module) {} void emit(Expression* curr); void emitScopeEnd(Expression* curr); diff --git a/src/wasm.h b/src/wasm.h index ba797d77375..61e9ef4f1d7 100644 --- a/src/wasm.h +++ b/src/wasm.h @@ -42,6 +42,8 @@ namespace wasm { +class Module; + // An index in a wasm module using Index = uint32_t; @@ -824,9 +826,11 @@ class Block : public SpecificExpression { Name name; ExpressionList list; + void finalize(); + // set the type purely based on its contents. this scans the block, so it is // not fast. - void finalize(); + void finalize(Module* wasm); // set the type given you know its type, which is the case when parsing // s-expression or binary, as explicit types are given. the only additional diff --git a/src/wasm/wasm-binary.cpp b/src/wasm/wasm-binary.cpp index 3c01e5df5ff..8a3833427ce 100644 --- a/src/wasm/wasm-binary.cpp +++ b/src/wasm/wasm-binary.cpp @@ -3106,7 +3106,7 @@ Expression* WasmBinaryReader::popNonVoidExpression() { assert(type == Type::unreachable); // nothing to do here - unreachable anyhow } - block->finalize(); + block->finalize(&wasm); return block; } diff --git a/src/wasm/wasm-ir-builder.cpp b/src/wasm/wasm-ir-builder.cpp index 255683fbc72..d5edceacb75 100644 --- a/src/wasm/wasm-ir-builder.cpp +++ b/src/wasm/wasm-ir-builder.cpp @@ -192,7 +192,7 @@ Result<> IRBuilder::visit(Expression* curr) { } else { // TODO: Call more efficient versions of finalize() that take the known type // for other kinds of nodes as well, as done above. - ReFinalizeNode{}.visit(curr); + ReFinalizeNode{wasm}.visit(curr); } push(curr); return Ok{}; diff --git a/src/wasm/wasm-s-parser.cpp b/src/wasm/wasm-s-parser.cpp index 7b415813ec6..a775cf80833 100644 --- a/src/wasm/wasm-s-parser.cpp +++ b/src/wasm/wasm-s-parser.cpp @@ -1697,7 +1697,7 @@ Expression* SExpressionWasmBuilder::makeThenOrElse(Element& s) { for (; i < s.size(); i++) { ret->list.push_back(parseExpression(s[i])); } - ret->finalize(); + ret->finalize(&wasm); return ret; } diff --git a/src/wasm/wasm-validator.cpp b/src/wasm/wasm-validator.cpp index 68d1f786df6..03a222901a2 100644 --- a/src/wasm/wasm-validator.cpp +++ b/src/wasm/wasm-validator.cpp @@ -3309,18 +3309,20 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { struct BinaryenIRValidator : public PostWalker> { + Module& wasm; ValidationInfo& info; std::unordered_set seen; - BinaryenIRValidator(ValidationInfo& info) : info(info) {} + BinaryenIRValidator(Module& wasm, ValidationInfo& info) + : wasm(wasm), info(info) {} void visitExpression(Expression* curr) { auto scope = getFunction() ? getFunction()->name : Name("(global scope)"); // check if a node type is 'stale', i.e., we forgot to finalize() the // node. auto oldType = curr->type; - ReFinalizeNode().visit(curr); + ReFinalizeNode(wasm).visit(curr); auto newType = curr->type; if (newType != oldType) { // We accept concrete => undefined on control flow structures: @@ -3357,7 +3359,7 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) { } } }; - BinaryenIRValidator binaryenIRValidator(info); + BinaryenIRValidator binaryenIRValidator(wasm, info); binaryenIRValidator.walkModule(&wasm); } diff --git a/src/wasm/wasm.cpp b/src/wasm/wasm.cpp index 24479c1cf5b..db52f8d8a67 100644 --- a/src/wasm/wasm.cpp +++ b/src/wasm/wasm.cpp @@ -175,6 +175,10 @@ handleUnreachable(Block* block, } void Block::finalize() { + // FIXME(frank-emrich) Anything to do here? +} + +void Block::finalize(Module* wasm) { if (list.size() == 0) { type = Type::none; return; @@ -189,7 +193,7 @@ void Block::finalize() { } // The default type is according to the value that flows out. - BranchUtils::BranchSeeker seeker(this->name); + BranchUtils::BranchTypeSeeker seeker(*wasm, this->name); Expression* temp = this; seeker.walk(temp); if (seeker.found) {