Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion sgl-router/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ python -m sglang_router.launch_router \
--prefill-selector app=sglang component=prefill \
--decode-selector app=sglang component=decode \
--service-discovery-namespace sglang-system

# With separate routing policies:
python -m sglang_router.launch_router \
--pd-disaggregation \
--prefill-policy cache_aware \
--decode-policy power_of_two \
--service-discovery \
--prefill-selector app=sglang component=prefill \
--decode-selector app=sglang component=decode \
--service-discovery-namespace sglang-system
```

#### Kubernetes Pod Configuration
Expand Down Expand Up @@ -226,7 +236,9 @@ python -m sglang_router.launch_router \
- `--decode`: Initial decode server URL
- `--prefill-selector`: Label selector for prefill pods
- `--decode-selector`: Label selector for decode pods
- `--policy`: Routing policy (`cache_aware`, `random`, `power_of_two`)
- `--policy`: Routing policy (`cache_aware`, `random`, `power_of_two`, `round_robin`)
- `--prefill-policy`: Separate routing policy for prefill nodes (optional, overrides `--policy` for prefill)
- `--decode-policy`: Separate routing policy for decode nodes (optional, overrides `--policy` for decode)

## Development

Expand Down
67 changes: 65 additions & 2 deletions sgl-router/py_src/sglang_router/launch_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class RouterArgs:

# Routing policy
policy: str = "cache_aware"
prefill_policy: Optional[str] = None # Specific policy for prefill nodes in PD mode
decode_policy: Optional[str] = None # Specific policy for decode nodes in PD mode
worker_startup_timeout_secs: int = 300
worker_startup_check_interval: int = 10
cache_threshold: float = 0.5
Expand Down Expand Up @@ -108,7 +110,21 @@ def add_cli_args(
type=str,
default=RouterArgs.policy,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
)
parser.add_argument(
f"--{prefix}prefill-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
)
parser.add_argument(
f"--{prefix}decode-policy",
type=str,
default=None,
choices=["random", "round_robin", "cache_aware", "power_of_two"],
help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
)

# PD-specific arguments
Expand Down Expand Up @@ -266,6 +282,8 @@ def from_cli_args(
prefill_urls=prefill_urls,
decode_urls=decode_urls,
policy=getattr(args, f"{prefix}policy"),
prefill_policy=getattr(args, f"{prefix}prefill_policy", None),
decode_policy=getattr(args, f"{prefix}decode_policy", None),
worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs"
),
Expand Down Expand Up @@ -389,6 +407,35 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
if not router_args.decode_urls:
raise ValueError("PD disaggregation mode requires --decode")

# Warn about policy usage in PD mode
if (
router_args.prefill_policy
and router_args.decode_policy
and router_args.policy
):
logger.warning(
"Both --prefill-policy and --decode-policy are specified. "
"The main --policy flag will be ignored for PD mode."
)
elif (
router_args.prefill_policy
and not router_args.decode_policy
and router_args.policy
):
logger.info(
f"Using --prefill-policy '{router_args.prefill_policy}' for prefill nodes "
f"and --policy '{router_args.policy}' for decode nodes."
)
elif (
router_args.decode_policy
and not router_args.prefill_policy
and router_args.policy
):
logger.info(
f"Using --policy '{router_args.policy}' for prefill nodes "
f"and --decode-policy '{router_args.decode_policy}' for decode nodes."
)

# Create router with unified constructor
router = Router(
worker_urls=(
Expand Down Expand Up @@ -424,6 +471,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
decode_urls=(
router_args.decode_urls if router_args.pd_disaggregation else None
),
prefill_policy=(
policy_from_str(router_args.prefill_policy)
if router_args.prefill_policy
else None
),
decode_policy=(
policy_from_str(router_args.decode_policy)
if router_args.decode_policy
else None
),
)

router.start()
Expand Down Expand Up @@ -455,12 +512,18 @@ def parse_router_args(args: List[str]) -> RouterArgs:
# Regular mode
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000

# PD disaggregated mode
# PD disaggregated mode with same policy for both
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--policy cache_aware

# PD mode with different policies for prefill and decode
python -m sglang_router.launch_router --pd-disaggregation \\
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
--decode http://decode1:8001 --decode http://decode2:8001 \\
--prefill-policy cache_aware --decode-policy power_of_two

""",
formatter_class=CustomHelpFormatter,
)
Expand Down
8 changes: 8 additions & 0 deletions sgl-router/py_src/sglang_router/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class Router:
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
decode_urls: List of URLs for decode servers (PD mode only)
prefill_policy: Specific load balancing policy for prefill nodes (PD mode only).
If not specified, uses the main policy. Default: None
decode_policy: Specific load balancing policy for decode nodes (PD mode only).
If not specified, uses the main policy. Default: None
"""

def __init__(
Expand Down Expand Up @@ -79,6 +83,8 @@ def __init__(
pd_disaggregation: bool = False,
prefill_urls: Optional[List[tuple]] = None,
decode_urls: Optional[List[str]] = None,
prefill_policy: Optional[PolicyType] = None,
decode_policy: Optional[PolicyType] = None,
):
if selector is None:
selector = {}
Expand Down Expand Up @@ -113,6 +119,8 @@ def __init__(
pd_disaggregation=pd_disaggregation,
prefill_urls=prefill_urls,
decode_urls=decode_urls,
prefill_policy=prefill_policy,
decode_policy=decode_policy,
)

def start(self) -> None:
Expand Down
Loading
Loading