Commit eddca9e

committed: WIP
1 parent 5230260 commit eddca9e

File tree: 5 files changed, +43 -15 lines changed

  csrc/host_ir/evaluator.cpp
  csrc/host_ir/lowering.cpp
  csrc/ir/base_nodes.cpp
  csrc/ir/base_nodes.h
  tests/cpp/test_multidevice_overlap.cpp

csrc/host_ir/evaluator.cpp

Lines changed: 3 additions & 6 deletions

@@ -497,12 +497,9 @@ void HostIrEvaluator::handle(LinearOp* linear) {
   auto* weight = linear->inB()->as<TensorView>();
   auto* out = linear->out()->as<TensorView>();
 
-  // FIXME: When LinearOp is called in a for loop, even if its output is not
-  // pre-allocated, the second iteration will see isKnown true and skip the
-  // unhandled path.
-  if (!expr_evaluator_.isKnown(out)) {
-    unhandled(linear);
-    return;
+  // FIXME: this breaks MultiDeviceExecutor.
+  if (!linear->outputIsPreallocated()) {
+    return unhandled(linear);
   }
 
   auto in_tensor = getKnownConcreteValue(in).as<at::Tensor>();
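
Note on the replaced check: expr_evaluator_.isKnown(out) is stateful, so once a first loop iteration materializes the output through the unhandled fallback, every later iteration sees isKnown == true and wrongly takes the preallocated fast path (this is what the removed FIXME described). outputIsPreallocated() is a property of the IR node itself and answers the same way on every iteration. A standalone sketch of the pitfall, using hypothetical stand-ins (known for the evaluator's bindings, preallocated for the new per-Expr flag), not the actual host IR types:

#include <iostream>
#include <unordered_set>

int main() {
  std::unordered_set<int> known; // stand-in for the evaluator's bindings (stateful!)
  const bool preallocated = false; // stand-in for the per-Expr flag (static)
  const int out = 42; // stand-in for the output TensorView

  for (int i = 0; i < 2; ++i) {
    // Old check: depends on evaluator state, flips after iteration 0.
    const bool old_fast_path = known.count(out) > 0;
    // New check: a static property of the expression, stable across iterations.
    const bool new_fast_path = preallocated;
    std::cout << "iter " << i << ": old=" << old_fast_path
              << " new=" << new_fast_path << '\n';
    known.insert(out); // the fallback path binds the output as a side effect
  }
  return 0;
}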

csrc/host_ir/lowering.cpp

Lines changed: 17 additions & 5 deletions

@@ -129,7 +129,10 @@ const std::vector<IterDomain*>& findReferenceLoopDomain(
 
 Expr* cloneWithNewOperands(
     Expr* e,
-    const std::unordered_map<Val*, Val*>& replacement_map) {
+    const std::unordered_map<Val*, Val*>& replacement_map,
+    bool output_is_preallocated) {
+  NVF_ERROR(!e->outputIsPreallocated());
+
   auto maybe_replace = [&](Val*& x) -> bool {
     Val* new_x = getOrDefault(replacement_map, x);
     if (new_x == nullptr) {
@@ -147,10 +150,16 @@ Expr* cloneWithNewOperands(
   std::vector<Val*> new_outs = e->outputs();
   replaced += std::ranges::count_if(new_outs, maybe_replace);
 
-  if (replaced == 0) {
+  if (replaced == 0 && !output_is_preallocated) {
     return e;
   }
-  return e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
+
+  Expr* new_e =
+      e->newObjectFunc()(e->container(), new_ins, new_outs, e->attributes());
+  if (output_is_preallocated) {
+    new_e = new_e->withOutputPreallocated();
+  }
+  return new_e;
 }
 
 void lowerSegment(
@@ -213,7 +222,7 @@ void lowerSegment(
         innermost_scope.push_back(allocate);
       }
 
-      Expr* new_c = cloneWithNewOperands(c, replacement_map);
+      Expr* new_c = cloneWithNewOperands(c, replacement_map, true);
       innermost_scope.push_back(new_c);
 
       auto* wait = IrBuilder::create<hir::Wait>(new_c);
@@ -267,10 +276,12 @@ void lowerSegment(
         }
       }
 
+      bool output_is_preallocated = false;
      for (auto* out : ir_utils::filterByType<TensorView>(e->outputs())) {
        if (getShardedIterDomain(out, ParallelType::Stream) == nullptr) {
          auto* allocate =
              IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
+          output_is_preallocated = true;
          innermost.parent_scope->insert(
              innermost.parent_insertion_point, allocate);
          // Loop is stream parallelized but allocation is not. Therefore,
@@ -285,7 +296,8 @@ void lowerSegment(
        }
      }
 
-      Expr* new_e = cloneWithNewOperands(e, replacement_map);
+      Expr* new_e =
+          cloneWithNewOperands(e, replacement_map, output_is_preallocated);
      innermost_scope.push_back(new_e);
    }
    break;
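
A subtlety in cloneWithNewOperands: even when no operand is replaced, a preallocated output now forces a fresh Expr, because withOutputPreallocated() mutates the node it is called on and the original must stay unmarked (which is what the NVF_ERROR at the top asserts). A minimal sketch of the resulting contract, assuming an Expr* e none of whose operands appear in replacement_map:

// Sketch only; `e` and `replacement_map` come from the surrounding lowering code.
Expr* same = cloneWithNewOperands(e, replacement_map, /*output_is_preallocated=*/false);
// same == e: nothing replaced and no flag to set, so no clone is made.

Expr* marked = cloneWithNewOperands(e, replacement_map, /*output_is_preallocated=*/true);
// marked != e: a new Expr is built solely to carry the flag, so `e` itself
// stays unmarked and can be lowered again later.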

csrc/ir/base_nodes.cpp

Lines changed: 8 additions & 1 deletion

@@ -253,6 +253,7 @@ std::optional<DataType> Val::getDataType() const {
 // after inputs and outputs are registered with the Expr
 Expr::Expr(IrBuilderPasskey passkey) : Statement(passkey) {}
 
+// FIXME: Should this constructor copy the output_is_preallocated_ flag?
 Expr::Expr(const Expr* src, IrCloner* ir_cloner)
     : Statement(src, ir_cloner),
       attributes_(ir_cloner->clone(src->attributes_)),
@@ -270,12 +271,13 @@ Expr::Expr(
       outputs_(std::move(outputs)) {}
 
 Expr* Expr::shallowCopy() const {
-  auto result =
+  Expr* result =
       newObjectFunc()(ir_container_, inputs(), outputs(), attributes());
   if (container()->isA<kir::Kernel>()) {
     result->predicate_ = predicate_;
     result->write_predicate_ = write_predicate_;
   }
+  result->output_is_preallocated_ = output_is_preallocated_;
   return result;
 }
 
@@ -383,6 +385,11 @@ Expr* Expr::withWritePredicate(kir::Predicate* predicate) {
   return result;
 }
 
+Expr* Expr::withOutputPreallocated() {
+  output_is_preallocated_ = true;
+  return this;
+}
+
 std::vector<PolymorphicValue> Expr::evaluate(
     const ExpressionEvaluator& ee,
     const std::vector<PolymorphicValue>& inputs) const {
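
Two propagation rules fall out of this file: shallowCopy() now carries the flag across copies, while the IrCloner constructor does not yet (hence the new FIXME). Note also that, unlike the with*Predicate helpers just above it, withOutputPreallocated() mutates in place and returns this rather than a copy. A hedged usage sketch (the placeholder expression is illustrative, not from the diff):

Expr* e = /* any Expr, e.g. the clone built in cloneWithNewOperands */;
Expr* marked = e->withOutputPreallocated();
// marked == e: the flag is set on the receiver itself; no copy is made.

Expr* copy = marked->shallowCopy();
// copy->outputIsPreallocated() == true: shallowCopy() propagates the flag
// unconditionally, even outside kir::Kernel containers.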

csrc/ir/base_nodes.h

Lines changed: 8 additions & 0 deletions

@@ -599,6 +599,12 @@ class NVF_API Expr : public Statement {
   // TODO: Protect based on being in kernel container
   Expr* withWritePredicate(kir::Predicate* write_predicate);
 
+  bool outputIsPreallocated() const {
+    return output_is_preallocated_;
+  }
+
+  Expr* withOutputPreallocated();
+
   // Get the name of an expression
   virtual const char* getOpString() const = 0;
 
@@ -660,6 +666,8 @@ class NVF_API Expr : public Statement {
 
   // Only used for reduction-related expressions
   kir::Predicate* write_predicate_ = nullptr;
+
+  bool output_is_preallocated_ = false;
 };
 
 template <typename T>

tests/cpp/test_multidevice_overlap.cpp

Lines changed: 7 additions & 3 deletions

@@ -77,9 +77,9 @@ TEST_F(StreamTest, RowParallelLinear_Forward) {
   constexpr int64_t t = 6;
   static_assert(t % s == 0);
   at::Tensor in_tensor =
-      at::randn({t, h * 4}, tensor_options_.dtype(at::kBFloat16));
+      at::randint(-2, 3, {t, h * 4}, tensor_options_.dtype(at::kBFloat16));
   at::Tensor w_tensor =
-      at::randn({h, h * 4}, tensor_options_.dtype(at::kBFloat16));
+      at::randint(-2, 3, {h, h * 4}, tensor_options_.dtype(at::kBFloat16));
   at::Tensor out_tensor = at::linear(in_tensor, w_tensor);
 
   at::Tensor sharded_in_tensor = shardTensor(in_tensor, in);
@@ -91,7 +91,11 @@ TEST_F(StreamTest, RowParallelLinear_Forward) {
           .runFusionWithInputs({sharded_in_tensor, sharded_w_tensor})[0]
           .as<at::Tensor>();
 
-  EXPECT_TRUE(at::allclose(sharded_out_tensor, out_tensor));
+  EXPECT_TRUE(at::allclose(sharded_out_tensor, out_tensor))
+      << "sharded_out_tensor:" << std::endl
+      << sharded_out_tensor << std::endl
+      << " out_tensor:" << std::endl
+      << out_tensor;
 }
 
 } // namespace nvfuser
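
On the switch from at::randn to at::randint(-2, 3, ...): bfloat16 has only 8 explicit mantissa bits, so a matmul over random normals rounds differently depending on how the reduction is split, which makes at::allclose flaky for stream-parallel schedules. Products and sums of small integers stay exactly representable, so the comparison becomes deterministic. A standalone ATen sketch of that idea (shapes are arbitrary, not the test's):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::manual_seed(0);
  // Integer values in [-2, 2] are exact in bf16; so are their products and
  // all partial sums below (bounded by 128, well under 2^9 = 512).
  at::Tensor a = at::randint(-2, 3, {6, 32}, at::dtype(at::kBFloat16));
  at::Tensor w = at::randint(-2, 3, {8, 32}, at::dtype(at::kBFloat16));

  // Full reduction vs. the same reduction split into two chunks, mimicking
  // how a row-parallel linear splits the reduced dimension across devices.
  at::Tensor full = at::linear(a, w);
  at::Tensor split = at::linear(a.narrow(1, 0, 16), w.narrow(1, 0, 16)) +
      at::linear(a.narrow(1, 16, 16), w.narrow(1, 16, 16));

  std::cout << at::equal(full, split) << std::endl; // prints 1: bitwise equal
}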
