Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 158 additions & 18 deletions test/srt/test_disaggregation_different_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

import requests

Expand All @@ -18,23 +19,30 @@
)


class TestDisaggregationMooncakeDifferentTP(CustomTestCase):
class TestDisaggregationMooncakePrefillLargerTP(CustomTestCase):
@classmethod
def setUpClass(cls):
# Temporarily disable JIT DeepGEMM
cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"

cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_host = "127.0.0.1"
cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
cls.lb_url = DEFAULT_URL_FOR_TEST
cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
cls.base_host = parsed_url.hostname
base_port = str(parsed_url.port)
cls.lb_port = base_port
cls.prefill_port = f"{int(base_port) + 100}"
cls.decode_port = f"{int(base_port) + 200}"
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
# Non blocking start servers
cls.start_prefill()
cls.start_decode()

# Block until both
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")

Expand All @@ -49,7 +57,7 @@ def setUpClass(cls):
"--host",
cls.base_host,
"--port",
str(cls.base_port),
cls.lb_port,
]

print("Starting load balancer:", " ".join(lb_command))
Expand All @@ -64,12 +72,10 @@ def start_prefill(cls):
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--host",
cls.base_host,
"--port",
str(cls.base_port + 100),
"--tp",
"4",
"2",
"--disaggregation-ib-device",
"mlx5_roce0,mlx5_roce1",
]
cls.process_prefill = popen_launch_pd_server(
cls.model,
Expand All @@ -84,14 +90,146 @@ def start_decode(cls):
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"1",
"--base-gpu-id",
"2",
"--disaggregation-ib-device",
"mlx5_roce2",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)

@classmethod
def wait_server_ready(cls, url, timeout=60):
    """Poll ``url`` until it answers with HTTP 200.

    Args:
        url: Full URL to probe with a GET request (callers pass a
            ``/health`` endpoint).
        timeout: Overall budget in seconds before giving up.

    Raises:
        RuntimeError: If no probe returned status 200 within ``timeout``
            seconds.
    """
    # Upper bound for a single probe. Without it, a connection that is
    # accepted but never answered would block forever and the overall
    # ``timeout`` below would never be evaluated.
    probe_timeout = 5
    deadline = time.perf_counter() + timeout

    while True:
        try:
            response = requests.get(url, timeout=probe_timeout)
            if response.status_code == 200:
                print(f"Server {url} is ready")
                return
        except requests.RequestException:
            # The server is not reachable (or not answering) yet; retry
            # until the deadline passes.
            pass

        if time.perf_counter() > deadline:
            raise RuntimeError(f"Server {url} failed to start in {timeout}s")
        time.sleep(1)

@classmethod
def tearDownClass(cls):
    """Undo the environment override and stop every launched process.

    The ``SGL_ENABLE_JIT_DEEPGEMM`` variable is put back to the value
    remembered in ``cls.original_jit_deepgemm`` (or removed when it was
    not set), then the load balancer, decode and prefill processes are
    killed on a best-effort basis.
    """
    previous_value = cls.original_jit_deepgemm
    if previous_value is None:
        # The variable did not exist before the test class ran.
        os.environ.pop("SGL_ENABLE_JIT_DEEPGEMM", None)
    else:
        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = previous_value

    launched = [cls.process_lb, cls.process_decode, cls.process_prefill]
    for process in launched:
        if not process:
            continue
        try:
            kill_process_tree(process.pid)
        except Exception as e:
            # Best effort: report the failure and keep killing the rest.
            print(f"Error killing process {process.pid}: {e}")

    # wait for 5 seconds
    time.sleep(5)

def test_gsm8k(self):
    """Run the few-shot GSM8K evaluation through the load balancer.

    The evaluation is pointed at ``self.base_host`` / ``self.lb_port`` and
    the test requires the reported accuracy to be above 0.60.
    """
    eval_settings = {
        "num_shots": 5,
        "data_path": None,
        "num_questions": 200,
        "max_new_tokens": 512,
        "parallel": 128,
        "host": f"http://{self.base_host}",
        # The port is stored as a string; the evaluator gets an int.
        "port": int(self.lb_port),
    }
    metrics = run_eval_few_shot_gsm8k(SimpleNamespace(**eval_settings))
    print(f"Evaluation metrics: {metrics}")

    accuracy = metrics["accuracy"]
    self.assertGreater(accuracy, 0.60)


class TestDisaggregationMooncakeDecodeLargerTP(CustomTestCase):
@classmethod
def setUpClass(cls):
    """Launch prefill, decode and load-balancer processes for the tests.

    Steps:
      1. Force ``SGL_ENABLE_JIT_DEEPGEMM`` to ``"false"``, remembering the
         previous value in ``cls.original_jit_deepgemm``.
      2. Derive host and ports from ``DEFAULT_URL_FOR_TEST``: the load
         balancer uses the base port, prefill uses base + 100 and decode
         uses base + 200.
      3. Start the prefill and decode servers and wait for their
         ``/health`` endpoints.
      4. Start the mini load balancer as a subprocess and wait for its
         ``/health`` endpoint.
    """
    # Temporarily disable JIT DeepGEMM
    cls.original_jit_deepgemm = os.environ.get("SGL_ENABLE_JIT_DEEPGEMM")
    os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"

    cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
    parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
    cls.base_host = parsed_url.hostname
    # Ports are kept as strings because they are passed as argv entries.
    base_port = str(parsed_url.port)
    cls.lb_port = base_port
    cls.prefill_port = f"{int(base_port) + 100}"
    cls.decode_port = f"{int(base_port) + 200}"
    cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
    cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
    cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
    print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

    # Non blocking start servers
    cls.start_prefill()
    cls.start_decode()

    # Block until both
    cls.wait_server_ready(cls.prefill_url + "/health")
    cls.wait_server_ready(cls.decode_url + "/health")

    # "--port" takes exactly one value: the load balancer port. The stale
    # ``str(cls.base_port + 200)`` entry was dropped; ``cls.base_port`` is
    # never assigned in this class and the extra value would also have
    # duplicated the port argument.
    lb_command = [
        "python3",
        "-m",
        "sglang.srt.disaggregation.mini_lb",
        "--prefill",
        cls.prefill_url,
        "--decode",
        cls.decode_url,
        "--host",
        cls.base_host,
        "--port",
        cls.lb_port,
    ]

    print("Starting load balancer:", " ".join(lb_command))
    # NOTE(review): stdout/stderr are piped but never read in this method;
    # confirm the load balancer cannot fill the pipe buffer and stall.
    cls.process_lb = subprocess.Popen(
        lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    cls.wait_server_ready(cls.lb_url + "/health")

@classmethod
def start_prefill(cls):
    """Launch the prefill server and keep its handle on the class.

    The server is started for ``cls.model`` at ``cls.prefill_url`` in
    prefill disaggregation mode with ``--tp 1`` and the ``mlx5_roce0``
    IB device; the result is stored in ``cls.process_prefill``.
    """
    # Valued flags, in the order they are placed on the command line.
    valued_flags = (
        ("--disaggregation-mode", "prefill"),
        ("--tp", "1"),
        ("--disaggregation-ib-device", "mlx5_roce0"),
    )
    server_flags = ["--trust-remote-code"]
    for flag, value in valued_flags:
        server_flags.extend((flag, value))

    cls.process_prefill = popen_launch_pd_server(
        cls.model,
        cls.prefill_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=server_flags,
    )

@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"2",
"--base-gpu-id",
"4",
"1",
"--disaggregation-ib-device",
"mlx5_roce1,mlx5_roce2",
]
cls.process_decode = popen_launch_pd_server(
cls.model,
Expand Down Expand Up @@ -130,6 +268,8 @@ def tearDownClass(cls):
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process {process.pid}: {e}")
# wait for 5 seconds
time.sleep(5)

def test_gsm8k(self):
args = SimpleNamespace(
Expand All @@ -138,8 +278,8 @@ def test_gsm8k(self):
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.lb_url.split(":")[-1]),
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Evaluation metrics: {metrics}")
Expand Down
Loading