@@ -129,7 +129,10 @@ const std::vector<IterDomain*>& findReferenceLoopDomain(
129129
130130Expr* cloneWithNewOperands (
131131 Expr* e,
132- const std::unordered_map<Val*, Val*>& replacement_map) {
132+ const std::unordered_map<Val*, Val*>& replacement_map,
133+ bool output_is_preallocated) {
134+ NVF_ERROR (!e->outputIsPreallocated ());
135+
133136 auto maybe_replace = [&](Val*& x) -> bool {
134137 Val* new_x = getOrDefault (replacement_map, x);
135138 if (new_x == nullptr ) {
@@ -147,10 +150,16 @@ Expr* cloneWithNewOperands(
147150 std::vector<Val*> new_outs = e->outputs ();
148151 replaced += std::ranges::count_if (new_outs, maybe_replace);
149152
150- if (replaced == 0 ) {
153+ if (replaced == 0 && !output_is_preallocated ) {
151154 return e;
152155 }
153- return e->newObjectFunc ()(e->container (), new_ins, new_outs, e->attributes ());
156+
157+ Expr* new_e =
158+ e->newObjectFunc ()(e->container (), new_ins, new_outs, e->attributes ());
159+ if (output_is_preallocated) {
160+ new_e = new_e->withOutputPreallocated ();
161+ }
162+ return new_e;
154163}
155164
156165void lowerSegment (
@@ -213,7 +222,7 @@ void lowerSegment(
213222 innermost_scope.push_back (allocate);
214223 }
215224
216- Expr* new_c = cloneWithNewOperands (c, replacement_map);
225+ Expr* new_c = cloneWithNewOperands (c, replacement_map, true );
217226 innermost_scope.push_back (new_c);
218227
219228 auto * wait = IrBuilder::create<hir::Wait>(new_c);
@@ -267,10 +276,12 @@ void lowerSegment(
267276 }
268277 }
269278
279+ bool output_is_preallocated = false ;
270280 for (auto * out : ir_utils::filterByType<TensorView>(e->outputs ())) {
271281 if (getShardedIterDomain (out, ParallelType::Stream) == nullptr ) {
272282 auto * allocate =
273283 IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
284+ output_is_preallocated = true ;
274285 innermost.parent_scope ->insert (
275286 innermost.parent_insertion_point , allocate);
276287 // Loop is stream parallelized but allocation is not. Therefore,
@@ -285,7 +296,8 @@ void lowerSegment(
285296 }
286297 }
287298
288- Expr* new_e = cloneWithNewOperands (e, replacement_map);
299+ Expr* new_e =
300+ cloneWithNewOperands (e, replacement_map, output_is_preallocated);
289301 innermost_scope.push_back (new_e);
290302 }
291303 break ;
0 commit comments