Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ci/auto_parallel/ci_auto_parallel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ get_diff_TO_case(){
case_list[${#case_list[*]}]=llama_auto
case_list[${#case_list[*]}]=gpt-3_auto
case_list[${#case_list[*]}]=gpt-3_dygraph
case_list[${#case_list[*]}]=deepseek_auto
}

print_info(){
Expand Down Expand Up @@ -258,6 +259,14 @@ if [[ ${#case_list[*]} -ne 0 ]];then
execute_func_list $cmd gpt-3_dygraph
let case_num++
clean_file ${work_dir}/../PaddleNLP/llm
elif [[ ${case} == "deepseek_auto" ]];then
cmd=${work_dir}/../PaddleNLP/scripts/distribute/ci_case_auto.sh
timeout 5m bash $cmd prepare_case deepseek_case_list_auto $FLAGS_install_deps $FLAGS_download_data
execute_func_list $cmd deepseek_auto
export FLAGS_install_deps=1
export FLAGS_download_data="deepseek ""$FLAGS_download_data"
let case_num++
clean_file ${work_dir}/../PaddleNLP/llm/auto_parallel/deepseek-v3
else
echo -e "\033[31m ---- no ${case} \033"
let case_num++
Expand Down
7 changes: 5 additions & 2 deletions python/paddle/distributed/auto_parallel/moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,13 @@ def get_rank2tensor_indices(sub_mesh_indices_info, sub_mesh_partial_info):


def get_local_slices(tensor, mesh, placements):
if len(mesh.shape) != len(placements):
if len(mesh.shape) < len(placements):
raise ValueError(
f"placements nums ({len(placements)}) must equal mesh_shape({len(mesh.shape)})"
f"placements length ({len(placements)}) must be smaller or equal to mesh_shape({len(mesh.shape)})"
)
if len(placements) < len(mesh.shape):
for _ in range(len(mesh.shape) - len(placements)):
placements.append(dist.Replicate())

sub_mesh_indices_info = {mesh: [(0, s) for s in tensor.shape]}
sub_mesh_partial_info = {}
Expand Down
22 changes: 21 additions & 1 deletion test/auto_parallel/semi_auto_parallel_moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
import unittest

import numpy as np

Expand All @@ -26,7 +27,7 @@
)


class TestMoEUtils:
class TestMoEUtils(unittest.TestCase):
def __init__(self):
self._dtype = os.getenv("dtype")
self._seeds = eval(os.getenv("seeds"))
Expand Down Expand Up @@ -160,6 +161,25 @@ def test_get_local_slices(self):
dist_x.placements[1].reduce_type(),
)

y = paddle.arange(0, h * w).reshape(src_shape)
y_placements = [dist.Shard(0)]
dist_y = dist.shard_tensor(y, self._mesh0, y_placements)
dist_y_local_slices = get_local_slices(
dist_y, self._mesh0, y_placements
)
np.testing.assert_equal(
dist_y_local_slices[0]['slice'], [(0, 2), (0, 4)]
)
np.testing.assert_equal(
dist_y_local_slices[1]['slice'], [(2, 4), (0, 4)]
)

with self.assertRaises(ValueError):
tmp_placements = [dist.Shard(0), dist.Shard(1), dist.Replicate()]
dist_y_local_slices = get_local_slices(
dist_y, self._mesh0, tmp_placements
)

# python -m paddle.distributed.launch --devices=0,1 semi_auto_parallel_moe_utils.py
def test_reshard_general_case(self):
"""Test reshard when _only_reshard_mesh_shape returns False."""
Expand Down
4 changes: 2 additions & 2 deletions test/auto_parallel/test_moe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def setUp(self):
num_of_devices=2,
timeout=30,
)
self._default_envs = {"dtype": "float32", "seed": "2024"}
self._default_envs = {"dtype": "float32", "seeds": "2024"}
self._changeable_envs = {"backend": ["gpu"]}

def test_moe_utils(self):
envs_list = test_base.gen_product_envs_list(
{
"dtype": "float32",
"seed": "2024",
"seeds": "2024",
"FLAGS_enable_moe_utils": "true",
},
{"backend": ["gpu"]},
Expand Down