Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions csrc/host_ir/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,36 @@ const std::vector<IterDomain*>& findReferenceLoopDomain(
return reference_tv->getLoopDomain();
}

// Returns a new Expr with the inputs and outputs replaced by the replacement
// map. If none of the inputs or outputs are replaced, returns the original
// Expr.
Expr* cloneWithNewOperands(
Expr* e,
const std::unordered_map<Val*, Val*>& replacement_map) {
auto maybe_replace = [&](Val*& x) -> bool {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also tried to use OptOutMutator (#5538) but ran into

[ RUN      ] StreamTest.Matmul
unknown file: Failure
C++ exception with description " INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/mutator.cpp:90, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. 
Expected !DependencyCheck::isDependencyOf(val, mutation) . Attempted to replace a val, T1_g_float[istreamIdx7{3}, iS11{i2}, iS8{( ceilDiv(i4, 3) )}], with a dependent val, T3_l_float[istreamIdx14{3}, iS12{i2}, iS15{( ceilDiv(i4, 3) )}] (T3_l_float[istreamIdx14{3}, iS12{i2}, iS15{( ceilDiv(i4, 3) )}]), which is not allowed as it would result in a recursive definition of T3_l_float[istreamIdx14{3}, iS12{i2}, iS15{( ceilDiv(i4, 3) )}]
Exception raised from registerMutation at /opt/pytorch/nvfuser/csrc/mutator.cpp:90 (most recent call first):

Let me know if there's an easier way.

Val* new_x = getOrDefault(replacement_map, x);
if (new_x == nullptr) {
return false;
}
x = new_x;
return true;
};

int64_t replaced = 0;

std::vector<Val*> new_ins = e->inputs();
replaced += std::ranges::count_if(new_ins, maybe_replace);

std::vector<Val*> new_outs = e->outputs();
replaced += std::ranges::count_if(new_outs, maybe_replace);

if (replaced == 0) {
return e;
}

return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
}

void lowerSegment(
const SegmentedGroup& group,
const AliasInfoMap& aliases,
Expand Down Expand Up @@ -219,24 +249,7 @@ void lowerSegment(
}
}

std::vector<Val*> new_inputs;
std::transform(
e->inputs().begin(),
e->inputs().end(),
std::back_inserter(new_inputs),
[&replacement_map](Val* input) {
return getOrDefault(replacement_map, input, input);
});
std::vector<Val*> new_outputs;
std::transform(
e->outputs().begin(),
e->outputs().end(),
std::back_inserter(new_outputs),
[&replacement_map](Val* output) {
return getOrDefault(replacement_map, output, output);
});
Expr* new_e = e->newObjectFunc()(
e->container(), new_inputs, new_outputs, e->attributes());
Expr* new_e = cloneWithNewOperands(e, replacement_map);
for_loop->body().push_back(new_e);
}
break;
Expand Down