@@ -115,7 +115,10 @@ const std::vector<IterDomain*>& findReferenceLoopDomain(
115115
116116Expr* cloneWithNewOperands (
117117 Expr* e,
118- const std::unordered_map<Val*, Val*>& replacement_map) {
118+ const std::unordered_map<Val*, Val*>& replacement_map,
119+ bool output_is_preallocated) {
120+ NVF_ERROR (!e->outputIsPreallocated ());
121+
119122 auto maybe_replace = [&](Val*& x) -> bool {
120123 Val* new_x = getOrDefault (replacement_map, x);
121124 if (new_x == nullptr ) {
@@ -133,10 +136,16 @@ Expr* cloneWithNewOperands(
133136 std::vector<Val*> new_outs = e->outputs ();
134137 replaced += std::ranges::count_if (new_outs, maybe_replace);
135138
136- if (replaced == 0 ) {
139+ if (replaced == 0 && !output_is_preallocated ) {
137140 return e;
138141 }
139- return e->newObjectFunc ()(e->container (), new_ins, new_outs, e->attributes ());
142+
143+ Expr* new_e =
144+ e->newObjectFunc ()(e->container (), new_ins, new_outs, e->attributes ());
145+ if (output_is_preallocated) {
146+ new_e = new_e->withOutputPreallocated ();
147+ }
148+ return new_e;
140149}
141150
142151void lowerSegment (
@@ -204,7 +213,7 @@ void lowerSegment(
204213 innermost_scope.push_back (allocate);
205214 }
206215
207- Expr* new_c = cloneWithNewOperands (c, replacement_map);
216+ Expr* new_c = cloneWithNewOperands (c, replacement_map, true );
208217 innermost_scope.push_back (new_c);
209218
210219 auto * wait = IrBuilder::create<hir::Wait>(new_c);
@@ -261,12 +270,14 @@ void lowerSegment(
261270 }
262271 }
263272
273+ bool output_is_preallocated = false ;
264274 for (auto * out : ir_utils::filterByType<TensorView>(e->outputs ())) {
265275 if (getShardedIterDomain (
266276 out, ParallelType::Stream, DomainType::kAllocation ) ==
267277 nullptr ) {
268278 auto * allocate =
269279 IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
280+ output_is_preallocated = true ;
270281 innermost.parent_scope ->insert (
271282 innermost.parent_insertion_point , allocate);
272283 // Loop is stream parallelized but allocation is not. Therefore,
@@ -281,7 +292,8 @@ void lowerSegment(
281292 }
282293 }
283294
284- Expr* new_e = cloneWithNewOperands (e, replacement_map);
295+ Expr* new_e =
296+ cloneWithNewOperands (e, replacement_map, output_is_preallocated);
285297 innermost_scope.push_back (new_e);
286298 }
287299 break ;
0 commit comments