paddle/fluid/memory/allocation/buddy_allocator.cc (2 changes: 1 addition & 1 deletion)
@@ -335,7 +335,7 @@ void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it,
 
   VLOG(10) << "Split block (" << block << ", " << desc->get_total_size()
            << ") into";
-  block->Split(&cache_, size);
+  block->Split(&cache_, size, extra_padding_size_);
 
   VLOG(10) << "Left block (" << block << ", " << desc->get_total_size() << ")";
   desc->set_type(MemoryBlock::ARENA_CHUNK);
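Note: `extra_padding_size_` is a BuddyAllocator member (presumably the allocator-level alignment padding for the device); the call site now forwards it so the split guard in memory_block.cc below can account for it.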
paddle/fluid/memory/allocation/memory_block.cc (7 changes: 5 additions & 2 deletions)
@@ -43,7 +43,9 @@ MemoryBlock* MemoryBlock::GetRightBuddy(MetadataCache* cache) {
   return cache->LoadDesc(this)->right_buddy;
 }
 
-void MemoryBlock::Split(MetadataCache* cache, size_t size) {
+void MemoryBlock::Split(MetadataCache* cache,
+                        size_t size,
+                        size_t extra_padding_size) {
   auto desc = cache->LoadDesc(this);
   // make sure the split fits
   PADDLE_ENFORCE_GE(desc->total_size,
@@ -55,7 +57,8 @@ void MemoryBlock::Split(MetadataCache* cache, size_t size) {
                         size));
 
   // bail out if there is no room for another partition
-  if (desc->total_size - size <= sizeof(MemoryBlock::Desc)) {
+  if (desc->total_size - size <=
+      (sizeof(MemoryBlock::Desc) + extra_padding_size)) {
     return;
   }
 
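The effect of threading `extra_padding_size` through is that a free block is now split only when the remainder can hold both a new block descriptor and the padding. A minimal standalone sketch of that guard, using hypothetical `FakeDesc`/`can_split` names rather than Paddle's actual types:

#include <cstddef>
#include <cstdio>

// Stand-in for MemoryBlock::Desc: only its size matters for the guard.
struct FakeDesc {
  size_t total_size;
};

// Mirrors the updated bail-out: split only if the leftover bytes can hold
// a new descriptor plus the extra padding reserved for alignment.
bool can_split(size_t total_size, size_t size, size_t extra_padding_size) {
  return total_size - size > sizeof(FakeDesc) + extra_padding_size;
}

int main() {
  // 64 leftover bytes fit a descriptor when no padding is reserved...
  printf("%d\n", can_split(1024, 960, 0));   // prints 1
  // ...but not once 64 bytes of padding must also fit.
  printf("%d\n", can_split(1024, 960, 64));  // prints 0
  return 0;
}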
paddle/fluid/memory/allocation/memory_block.h (2 changes: 1 addition & 1 deletion)
@@ -50,7 +50,7 @@ struct MemoryBlock {
   MemoryBlock* GetRightBuddy(MetadataCache* cache);
 
   // Split the allocation into left/right blocks.
-  void Split(MetadataCache* cache, size_t size);
+  void Split(MetadataCache* cache, size_t size, size_t extra_padding_size = 0);
 
   // Merge left and right blocks together.
   void Merge(MetadataCache* cache, MemoryBlock* right_buddy);
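Giving `extra_padding_size` a default of 0 keeps the change source-compatible: existing two-argument callers of `Split` compile unchanged and keep the old behavior, while BuddyAllocator::SplitToAlloc opts in by passing its padding explicitly.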
paddle/phi/backends/custom/custom_device.cc (1 change: 1 addition & 0 deletions)
@@ -595,6 +595,7 @@ class CustomDevice : public DeviceInterface {
     return_result(phi::DataType::FLOAT64, FLOAT64);
     return_result(phi::DataType::FLOAT32, FLOAT32);
     return_result(phi::DataType::FLOAT16, FLOAT16);
+    return_result(phi::DataType::BFLOAT16, BFLOAT16);
     return_result(phi::DataType::INT64, INT64);
     return_result(phi::DataType::INT32, INT32);
     return_result(phi::DataType::INT16, INT16);
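The `return_result` entries map phi data types to the custom-device runtime's counterparts; the new BFLOAT16 line presumably lets bfloat16 tensors on plugin devices resolve to a concrete device type instead of falling through to the unsupported-type path.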
test/custom_runtime/process_group_xccl.py (16 changes: 8 additions & 8 deletions)
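The test changes are mechanical: every progress print gains `flush=True`. In multi-process runs, stdout is typically block-buffered when redirected to a file or pipe, so without an explicit flush these markers can be lost if a rank hangs or aborts; flushing makes the last completed step visible in the logs.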
@@ -68,7 +68,7 @@ def test_create_process_group_xccl(self):
         task.wait()
         # assert np.array_equal(tensor_y, sum_result)
 
-        print("test allreduce sum api ok")
+        print("test allreduce sum api ok", flush=True)
 
         x = np.random.random(self.shape).astype(self.dtype)
         tensor_x = paddle.to_tensor(x)
@@ -86,7 +86,7 @@ def test_create_process_group_xccl(self):
         task.wait()
         # assert np.array_equal(tensor_y, max_result)
 
-        print("test allreduce max api ok")
+        print("test allreduce max api ok", flush=True)
 
         # test broadcast
         # rank 0
@@ -110,7 +110,7 @@ def test_create_process_group_xccl(self):
         assert task.is_completed()
         # assert np.array_equal(broadcast_result, tensor_y)
 
-        print("test broadcast api ok")
+        print("test broadcast api ok", flush=True)
 
         # test barrier
         # rank 0
@@ -122,7 +122,7 @@ def test_create_process_group_xccl(self):
             task = pg.barrier(device_id)
             task.wait()
 
-        print("test barrier api ok\n")
+        print("test barrier api ok\n", flush=True)
         return
 
         # test allgather
@@ -150,7 +150,7 @@ def test_create_process_group_xccl(self):
         )
         # assert np.array_equal(tensor_x, out_1)
         # assert np.array_equal(tensor_y, out_2)
-        print("test allgather api ok\n")
+        print("test allgather api ok\n", flush=True)
 
         # test alltoall
         # rank 0
@@ -183,7 +183,7 @@ def test_create_process_group_xccl(self):
         # assert np.array_equal(out1_2.numpy(), raw_tensor_y_1.numpy())
         # else:
         # assert np.array_equal(out2_1, raw_tensor_x_2)
-        print("test alltoall api ok\n")
+        print("test alltoall api ok\n", flush=True)
 
         # test Reduce
         # rank 0
@@ -203,7 +203,7 @@ def test_create_process_group_xccl(self):
         # paddle.base.core._custom_device_synchronize("custom_cpu", -1)
         # if pg.rank() == 0:
         # assert np.array_equal(tensor_x, sum_result)
-        print("test reduce sum api ok\n")
+        print("test reduce sum api ok\n", flush=True)
 
         # test Scatter
         # rank 0
@@ -228,7 +228,7 @@ def test_create_process_group_xccl(self):
         # assert np.array_equal(tensor_y, out1)
         # else:
         # assert np.array_equal(tensor_y, out2)
-        print("test scatter api ok\n")
+        print("test scatter api ok\n", flush=True)
 
 
 if __name__ == "__main__":