Skip to content

Commit 6111986

Browse files
committed
Fix how HostIrEvaluator detects pre-allocated outputs
1 parent 39b1f29 commit 6111986

File tree

9 files changed

+56
-27
lines changed

9 files changed

+56
-27
lines changed

csrc/host_ir/evaluator.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -482,27 +482,24 @@ void HostIrEvaluator::handle(MatmulOp* matmul) {
482482
TensorView* b = matmul->inB();
483483
TensorView* out = matmul->out();
484484

485-
if (expr_evaluator_.isKnown(out)) {
486-
auto t_a = getKnownConcreteValue(a).as<at::Tensor>();
487-
auto t_b = getKnownConcreteValue(b).as<at::Tensor>();
488-
auto t_out = getKnownConcreteValue(out).as<at::Tensor>();
489-
at::matmul_out(t_out, t_a, t_b);
490-
} else {
485+
if (!matmul->outputIsPreallocated()) {
491486
unhandled(matmul);
487+
return;
492488
}
489+
490+
auto t_a = getKnownConcreteValue(a).as<at::Tensor>();
491+
auto t_b = getKnownConcreteValue(b).as<at::Tensor>();
492+
auto t_out = getKnownConcreteValue(out).as<at::Tensor>();
493+
at::matmul_out(t_out, t_a, t_b);
493494
}
494495

495496
void HostIrEvaluator::handle(LinearOp* linear) {
496497
auto* in = linear->inA()->as<TensorView>();
497498
auto* weight = linear->inB()->as<TensorView>();
498499
auto* out = linear->out()->as<TensorView>();
499500

500-
// FIXME: When LinearOp is called in a for loop, even if it's output is not
501-
// pre-allocated, the second iteration will see isKnown true and skip the
502-
// unhandled path.
503-
if (!expr_evaluator_.isKnown(out)) {
504-
unhandled(linear);
505-
return;
501+
if (!linear->outputIsPreallocated()) {
502+
return unhandled(linear);
506503
}
507504

508505
auto in_tensor = getKnownConcreteValue(in).as<at::Tensor>();

csrc/host_ir/lowering.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@ const std::vector<IterDomain*>& findReferenceLoopDomain(
115115

116116
Expr* cloneWithNewOperands(
117117
Expr* e,
118-
const std::unordered_map<Val*, Val*>& replacement_map) {
118+
const std::unordered_map<Val*, Val*>& replacement_map,
119+
bool output_is_preallocated) {
120+
NVF_ERROR(!e->outputIsPreallocated());
121+
119122
auto maybe_replace = [&](Val*& x) -> bool {
120123
Val* new_x = getOrDefault(replacement_map, x);
121124
if (new_x == nullptr) {
@@ -133,10 +136,16 @@ Expr* cloneWithNewOperands(
133136
std::vector<Val*> new_outs = e->outputs();
134137
replaced += std::ranges::count_if(new_outs, maybe_replace);
135138

136-
if (replaced == 0) {
139+
if (replaced == 0 && !output_is_preallocated) {
137140
return e;
138141
}
139-
return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
142+
143+
Expr* new_e =
144+
e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
145+
if (output_is_preallocated) {
146+
new_e = new_e->withOutputPreallocated();
147+
}
148+
return new_e;
140149
}
141150

142151
void lowerSegment(
@@ -204,7 +213,7 @@ void lowerSegment(
204213
innermost_scope.push_back(allocate);
205214
}
206215

207-
Expr* new_c = cloneWithNewOperands(c, replacement_map);
216+
Expr* new_c = cloneWithNewOperands(c, replacement_map, true);
208217
innermost_scope.push_back(new_c);
209218

210219
auto* wait = IrBuilder::create<hir::Wait>(new_c);
@@ -261,12 +270,14 @@ void lowerSegment(
261270
}
262271
}
263272

273+
bool output_is_preallocated = false;
264274
for (auto* out : ir_utils::filterByType<TensorView>(e->outputs())) {
265275
if (getShardedIterDomain(
266276
out, ParallelType::Stream, DomainType::kAllocation) ==
267277
nullptr) {
268278
auto* allocate =
269279
IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
280+
output_is_preallocated = true;
270281
innermost.parent_scope->insert(
271282
innermost.parent_insertion_point, allocate);
272283
// Loop is stream parallelized but allocation is not. Therefore,
@@ -281,7 +292,8 @@ void lowerSegment(
281292
}
282293
}
283294

284-
Expr* new_e = cloneWithNewOperands(e, replacement_map);
295+
Expr* new_e =
296+
cloneWithNewOperands(e, replacement_map, output_is_preallocated);
285297
innermost_scope.push_back(new_e);
286298
}
287299
break;

csrc/host_ir/pass/stream_parallel_type.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ std::list<Expr*> processForLoopBodies(
475475
ir_utils::filterByType<TensorView>(body_expr->outputs())) {
476476
processTensor(body_expr, output, tensor_index);
477477
}
478+
body_expr = body_expr->withOutputPreallocated();
478479
new_loop_body.push_back(body_expr);
479480
}
480481
}

csrc/ir/base_nodes.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ std::optional<DataType> Val::getDataType() const {
253253
// after inputs and outputs are registered with the Expr
254254
Expr::Expr(IrBuilderPasskey passkey) : Statement(passkey) {}
255255

256+
// FIXME: Should this constructor copy the output_is_preallocated_ flag?
256257
Expr::Expr(const Expr* src, IrCloner* ir_cloner)
257258
: Statement(src, ir_cloner),
258259
attributes_(ir_cloner->clone(src->attributes_)),
@@ -270,12 +271,13 @@ Expr::Expr(
270271
outputs_(std::move(outputs)) {}
271272

272273
Expr* Expr::shallowCopy() const {
273-
auto result =
274+
Expr* result =
274275
newObjectFunc()(ir_container_, inputs(), outputs(), attributes());
275276
if (container()->isA<kir::Kernel>()) {
276277
result->predicate_ = predicate_;
277278
result->write_predicate_ = write_predicate_;
278279
}
280+
result->output_is_preallocated_ = output_is_preallocated_;
279281
return result;
280282
}
281283

@@ -383,6 +385,11 @@ Expr* Expr::withWritePredicate(kir::Predicate* predicate) {
383385
return result;
384386
}
385387

388+
Expr* Expr::withOutputPreallocated() {
389+
output_is_preallocated_ = true;
390+
return this;
391+
}
392+
386393
std::vector<PolymorphicValue> Expr::evaluate(
387394
const ExpressionEvaluator& ee,
388395
const std::vector<PolymorphicValue>& inputs) const {

csrc/ir/base_nodes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,12 @@ class NVF_API Expr : public Statement {
599599
// TODO: Protect based on being in kernel container
600600
Expr* withWritePredicate(kir::Predicate* write_predicate);
601601

602+
bool outputIsPreallocated() const {
603+
return output_is_preallocated_;
604+
}
605+
606+
Expr* withOutputPreallocated();
607+
602608
// Get the name of an expression
603609
virtual const char* getOpString() const = 0;
604610

@@ -660,6 +666,8 @@ class NVF_API Expr : public Statement {
660666

661667
// Only used for reduction-related expressions
662668
kir::Predicate* write_predicate_ = nullptr;
669+
670+
bool output_is_preallocated_ = false;
663671
};
664672

665673
template <typename T>

tests/cpp/test_host_ir_evaluator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ TEST_F(HostIrEvaluatorTest, MatmulInLoop) {
159159

160160
// By default, MatmulOp is computed by ExpressionEvaluator so it appears in
161161
// host IR.
162-
auto* mm = IrBuilder::create<MatmulOp>(loop_out, in, loop_w);
162+
auto* mm = IrBuilder::create<MatmulOp>(loop_out, in, loop_w)
163+
->withOutputPreallocated();
163164
for_loop->body().push_back(mm);
164165

165166
hic->pushBackTopLevelExprs(allocate_out);

tests/cpp/test_host_ir_stream_lowering.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
*/
77
// clang-format on
88

9-
#include <algorithm>
10-
#include <iostream>
11-
129
#include <gmock/gmock-matchers.h>
1310
#include <gtest/gtest.h>
1411

tests/cpp/test_host_irs.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,8 @@ TEST_F(MatmulHostIrTest, HostIrMatmulOut) {
874874
TensorView* tv0 = makeContigTensor(3);
875875
TensorView* tv1 = makeContigTensor(3);
876876
TensorView* tv2 = makeContigTensor(3);
877-
auto* matmul = IrBuilder::create<MatmulOp>(tv2, tv0, tv1);
877+
auto* matmul =
878+
IrBuilder::create<MatmulOp>(tv2, tv0, tv1)->withOutputPreallocated();
878879

879880
hic->addInput(tv0);
880881
hic->addInput(tv1);
@@ -956,7 +957,8 @@ TEST_F(LinearHostIrTest, HostIrLinearOut) {
956957
TensorView* bias = makeContigTensor(1);
957958
TensorView* out = makeContigTensor(3);
958959

959-
auto linear_op = IrBuilder::create<LinearOp>(out, in, weight, bias);
960+
auto* linear_op = IrBuilder::create<LinearOp>(out, in, weight, bias)
961+
->withOutputPreallocated();
960962

961963
hic->addInput(in);
962964
hic->addInput(weight);

tests/cpp/test_multidevice_overlap.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ TEST_F(StreamTest, RowParallelLinear_Forward) {
7777
constexpr int64_t t = 6;
7878
static_assert(t % s == 0);
7979
at::Tensor in_tensor =
80-
at::randn({t, h * 4}, tensor_options_.dtype(at::kBFloat16));
80+
at::randint(-2, 3, {t, h * 4}, tensor_options_.dtype(at::kBFloat16));
8181
at::Tensor w_tensor =
82-
at::randn({h, h * 4}, tensor_options_.dtype(at::kBFloat16));
82+
at::randint(-2, 3, {h, h * 4}, tensor_options_.dtype(at::kBFloat16));
8383
at::Tensor out_tensor = at::linear(in_tensor, w_tensor);
8484

8585
at::Tensor sharded_in_tensor = shardTensor(in_tensor, in);
@@ -91,7 +91,11 @@ TEST_F(StreamTest, RowParallelLinear_Forward) {
9191
.runFusionWithInputs({sharded_in_tensor, sharded_w_tensor})[0]
9292
.as<at::Tensor>();
9393

94-
EXPECT_TRUE(at::allclose(sharded_out_tensor, out_tensor));
94+
EXPECT_TRUE(at::allclose(sharded_out_tensor, out_tensor))
95+
<< "sharded_out_tensor:" << std::endl
96+
<< sharded_out_tensor << std::endl
97+
<< " out_tensor:" << std::endl
98+
<< out_tensor;
9599
}
96100

97101
} // namespace nvfuser

0 commit comments

Comments
 (0)