Skip to content

Commit 82eb4cc

Browse files
committed
Fix MultiDeviceExecutor
1 parent eddca9e commit 82eb4cc

File tree

5 files changed

+16
-12
lines changed

5 files changed

+16
-12
lines changed

csrc/host_ir/evaluator.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -482,14 +482,15 @@ void HostIrEvaluator::handle(MatmulOp* matmul) {
482482
TensorView* b = matmul->inB();
483483
TensorView* out = matmul->out();
484484

485-
if (expr_evaluator_.isKnown(out)) {
486-
auto t_a = getKnownConcreteValue(a).as<at::Tensor>();
487-
auto t_b = getKnownConcreteValue(b).as<at::Tensor>();
488-
auto t_out = getKnownConcreteValue(out).as<at::Tensor>();
489-
at::matmul_out(t_out, t_a, t_b);
490-
} else {
485+
if (!matmul->outputIsPreallocated()) {
491486
unhandled(matmul);
487+
return;
492488
}
489+
490+
auto t_a = getKnownConcreteValue(a).as<at::Tensor>();
491+
auto t_b = getKnownConcreteValue(b).as<at::Tensor>();
492+
auto t_out = getKnownConcreteValue(out).as<at::Tensor>();
493+
at::matmul_out(t_out, t_a, t_b);
493494
}
494495

495496
void HostIrEvaluator::handle(LinearOp* linear) {
@@ -498,6 +499,8 @@ void HostIrEvaluator::handle(LinearOp* linear) {
498499
auto* out = linear->out()->as<TensorView>();
499500

500501
// FIXME: this breaks MultiDeviceExecutor.
502+
std::cout << "linear->outputIsPreallocated(): "
503+
<< linear->outputIsPreallocated() << std::endl;
501504
if (!linear->outputIsPreallocated()) {
502505
return unhandled(linear);
503506
}

csrc/host_ir/pass/stream_parallel_type.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,7 @@ std::list<Expr*> processForLoopBodies(
475475
ir_utils::filterByType<TensorView>(body_expr->outputs())) {
476476
processTensor(body_expr, output, tensor_index);
477477
}
478+
body_expr = body_expr->withOutputPreallocated();
478479
new_loop_body.push_back(body_expr);
479480
}
480481
}

tests/cpp/test_host_ir_evaluator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ TEST_F(HostIrEvaluatorTest, MatmulInLoop) {
159159

160160
// By default, MatmulOp is computed by ExpressionEvaluator so it appears in
161161
// host IR.
162-
auto* mm = IrBuilder::create<MatmulOp>(loop_out, in, loop_w);
162+
auto* mm = IrBuilder::create<MatmulOp>(loop_out, in, loop_w)
163+
->withOutputPreallocated();
163164
for_loop->body().push_back(mm);
164165

165166
hic->pushBackTopLevelExprs(allocate_out);

tests/cpp/test_host_ir_stream_lowering.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
*/
77
// clang-format on
88

9-
#include <algorithm>
10-
#include <iostream>
11-
129
#include <gmock/gmock-matchers.h>
1310
#include <gtest/gtest.h>
1411

tests/cpp/test_host_irs.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,8 @@ TEST_F(MatmulHostIrTest, HostIrMatmulOut) {
874874
TensorView* tv0 = makeContigTensor(3);
875875
TensorView* tv1 = makeContigTensor(3);
876876
TensorView* tv2 = makeContigTensor(3);
877-
auto* matmul = IrBuilder::create<MatmulOp>(tv2, tv0, tv1);
877+
auto* matmul =
878+
IrBuilder::create<MatmulOp>(tv2, tv0, tv1)->withOutputPreallocated();
878879

879880
hic->addInput(tv0);
880881
hic->addInput(tv1);
@@ -956,7 +957,8 @@ TEST_F(LinearHostIrTest, HostIrLinearOut) {
956957
TensorView* bias = makeContigTensor(1);
957958
TensorView* out = makeContigTensor(3);
958959

959-
auto linear_op = IrBuilder::create<LinearOp>(out, in, weight, bias);
960+
auto* linear_op = IrBuilder::create<LinearOp>(out, in, weight, bias)
961+
->withOutputPreallocated();
960962

961963
hic->addInput(in);
962964
hic->addInput(weight);

0 commit comments

Comments
 (0)