Skip to content

Commit 3ed70b2

Browse files
authored
Fix buffer reusing (#2490)
1 parent 08dc16d commit 3ed70b2

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

third_party/nvfuser/csrc/kernel_ir.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,7 @@ Allocate::Allocate(
157157
TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type, "Invalid alias");
158158
}
159159

160-
// FIXME: there is a bug in lower_alias_memory.cpp that causes
161-
// `NVFuserTest.FusionPredicateElimination6_CUDA` to fail if I simplify `5*2`
162-
// into `10`
163-
164-
// size = simplifyExpr(size);
160+
size = simplifyExpr(size);
165161

166162
addInput(size);
167163
addAttribute(buffer);

third_party/nvfuser/csrc/lower_alias_memory.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -321,18 +321,22 @@ class BufferReuseDebugPrinter {
321321
//! The first write and last read
322322
//! is based on the position on the linear order within
323323
//! the Kernel IR.
324-
//! The interval is semi-open,
325-
//! i.e. [First_Write, Last_Read)
326-
//! So the buffer is NOT available at exactly First_Write
327-
//! position while it IS available at Last_Read.
324+
//! The interval is closed,
325+
//! i.e. [First_Write, Last_Read]
326+
//! So the buffer is NOT available from First_Write to
327+
//! Last_Read position. For the case where First_Write
328+
//! and Last_Read are identical, we can actually reuse the
329+
//! buffer if the read and write have exactly the same
330+
//! index; however, for simplicity, we are not taking
331+
//! advantage of this opportunity yet.
328332
class BufferLiveInterval {
329333
public:
330334
// Simple detection of intersection of two intervals
331335
bool intersect(BufferLiveInterval* other) {
332336
if (first_write_pos_ <= other->first_write_pos_) {
333-
return other->first_write_pos_ < last_read_pos_;
337+
return other->first_write_pos_ <= last_read_pos_;
334338
} else {
335-
return first_write_pos_ < other->last_read_pos_;
339+
return first_write_pos_ <= other->last_read_pos_;
336340
}
337341
}
338342

third_party/nvfuser/test/test_gpu2.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8255,18 +8255,21 @@ TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) {
82558255
auto tv1 = add(tv0, IrBuilder::create<Double>(1));
82568256
auto tv2 = add(tv1, IrBuilder::create<Double>(1));
82578257
auto tv3 = add(tv2, IrBuilder::create<Double>(1));
8258+
auto tv4 = add(tv3, IrBuilder::create<Double>(1));
82588259

8259-
fusion.addOutput(tv3);
8260+
fusion.addOutput(tv4);
82608261

82618262
tv1->setMemoryType(MemoryType::Shared);
82628263
tv2->setMemoryType(MemoryType::Shared);
8264+
tv3->setMemoryType(MemoryType::Shared);
82638265

8264-
tv3->split(0, 4);
8265-
tv0->computeAt(tv3, 1);
8266+
tv4->split(0, 4);
8267+
tv0->computeAt(tv4, 1);
82668268

82678269
tv1->axis(-1)->parallelize(ParallelType::TIDx);
82688270
tv2->axis(-1)->parallelize(ParallelType::TIDy);
82698271
tv3->axis(-1)->parallelize(ParallelType::TIDz);
8272+
tv4->axis(-1)->parallelize(ParallelType::TIDx);
82708273

82718274
// Make sure a WAR sync is inserted at the end of the outer loop
82728275
GpuLower gpulw(&fusion);
@@ -8291,7 +8294,7 @@ TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) {
82918294
fe.compileFusion(&fusion, aten_inputs);
82928295
auto outputs = fe.runFusion(aten_inputs);
82938296

8294-
auto ref1 = t0 + 3;
8297+
auto ref1 = t0 + 4;
82958298

82968299
testValidate(&fusion, outputs, aten_inputs, {ref1}, __LINE__, __FILE__);
82978300
}

0 commit comments

Comments
 (0)