Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions python/sglang/srt/disaggregation/common/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
from sglang.srt.utils import (
format_tcp_address,
get_free_port,
get_ip,
get_local_ip_by_remote,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -65,11 +72,18 @@ def __init__(
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr
else:
host, _ = self.dist_init_addr.rsplit(":", 1)
else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else:
ip_address = get_ip()
host = get_ip()
host = maybe_wrap_ipv6_address(host)

bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
Expand All @@ -92,8 +106,10 @@ def _register_to_bootstrap(self):
logger.error(f"Prefill Failed to register to bootstrap server: {e}")

@cache
def _connect(self, endpoint: str):
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket

Expand Down Expand Up @@ -263,15 +279,27 @@ def _get_prefill_dp_size_from_server(self) -> int:
return None

@classmethod
def _connect(cls, endpoint: str):
def _connect(cls, endpoint: str, is_ipv6: bool = False):
with cls._global_lock:
if endpoint not in cls._socket_cache:
sock = cls._ctx.socket(zmq.PUSH)
if is_ipv6:
sock.setsockopt(zmq.IPV6, 1)
sock.connect(endpoint)
cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

@classmethod
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
ip_address = bootstrap_info["rank_ip"]
port = bootstrap_info["rank_port"]
is_ipv6_address = is_valid_ipv6_address(ip_address)
sock, lock = cls._connect(
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
)
return sock, lock

def _register_kv_args(self):
pass

Expand Down
5 changes: 3 additions & 2 deletions python/sglang/srt/disaggregation/mini_lb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

from sglang.srt.disaggregation.utils import PDRegistryRequest
from sglang.srt.utils import maybe_wrap_ipv6_address

AIOHTTP_STREAM_READ_CHUNK_SIZE = (
1024 * 64
Expand Down Expand Up @@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict):

# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()

batch_size = _get_request_batch_size(modified_request)
Expand Down Expand Up @@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str):

# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
modified_request.update(
{
Expand Down
69 changes: 49 additions & 20 deletions python/sglang/srt/disaggregation/mooncake/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,15 @@
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
from sglang.srt.utils import (
format_tcp_address,
get_free_port,
get_int_env_var,
get_ip,
get_local_ip_auto,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -148,6 +156,9 @@ def __init__(
self.request_status: Dict[int, KVPoll] = {}
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)

self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
Expand Down Expand Up @@ -240,8 +251,10 @@ def register_buffer_to_engine(self):
self.engine.register(aux_data_ptr, aux_data_len)

@cache
def _connect(self, endpoint: str):
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket

Expand Down Expand Up @@ -483,9 +496,9 @@ def send_aux(
def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
):
if ":" in remote:
remote = remote.split(":")[0]
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
self._connect(
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
).send_multipart(
[
str(room).encode("ascii"),
str(status).encode("ascii"),
Expand Down Expand Up @@ -628,9 +641,12 @@ def transfer_worker(
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
)

def _bind_server_socket(self):
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))

def start_prefill_thread(self):
self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
self._bind_server_socket()

def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine"""
Expand Down Expand Up @@ -669,7 +685,7 @@ def bootstrap_thread():

def start_decode_thread(self):
self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
self._bind_server_socket()

def decode_thread():
while True:
Expand Down Expand Up @@ -788,7 +804,7 @@ def add_transfer_request(
# requests with the same dst_sessions will be added into the same
# queue, which enables early abort with failed sessions.
dst_infos = self.transfer_infos[bootstrap_room].keys()
session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
shard_idx = session_port_sum % len(self.transfer_queues)

self.transfer_queues[shard_idx].put(
Expand Down Expand Up @@ -826,11 +842,18 @@ def get_session_id(self):
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr
else:
host, _ = self.dist_init_addr.rsplit(":", 1)
else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else:
ip_address = get_ip()
host = get_ip()
host = maybe_wrap_ipv6_address(host)

bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
Expand Down Expand Up @@ -1175,9 +1198,6 @@ def _get_prefill_parallel_info_from_server(self) -> Tuple[int, int]:

def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
Expand All @@ -1191,7 +1211,7 @@ def _register_kv_args(self):
dst_tp_size = str(tp_size).encode("ascii")
dst_kv_item_len = str(kv_item_len).encode("ascii")

sock, lock = self._connect("tcp://" + self.prefill_server_url)
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
with lock:
sock.send_multipart(
[
Expand All @@ -1208,23 +1228,32 @@ def _register_kv_args(self):
)

@classmethod
def _connect(cls, endpoint: str):
def _connect(cls, endpoint: str, is_ipv6: bool = False):
with cls._global_lock:
if endpoint not in cls._socket_cache:
sock = cls._ctx.socket(zmq.PUSH)
if is_ipv6:
sock.setsockopt(zmq.IPV6, 1)
sock.connect(endpoint)
cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

@classmethod
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
ip_address = bootstrap_info["rank_ip"]
port = bootstrap_info["rank_port"]
is_ipv6_address = is_valid_ipv6_address(ip_address)
sock, lock = cls._connect(
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
)
return sock, lock

def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
is_dummy = bootstrap_info["is_dummy"]

sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import List, Optional

from sglang.srt.utils import get_bool_env_var, get_free_port
from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address

logger = logging.getLogger(__name__)

Expand All @@ -27,7 +27,9 @@ def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
hostname=self.hostname,
device_name=self.ib_device,
)
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
self.session_id = (
f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
)

def register(self, ptr, length):
try:
Expand Down
30 changes: 17 additions & 13 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_local_ip_by_remote
from sglang.srt.utils import (
format_tcp_address,
get_local_ip_auto,
is_valid_ipv6_address,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -124,7 +128,10 @@ def __init__(
"to run SGLang with NixlTransferEngine."
) from e
self.agent = nixl_agent(str(uuid.uuid4()))
self.local_ip = get_local_ip_auto()
self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.register_buffer_to_engine()

if self.disaggregation_mode == DisaggregationMode.PREFILL:
Expand Down Expand Up @@ -337,8 +344,11 @@ def check_transfer_done(self, room: int):
return False
return self.transfer_statuses[room].is_done()

def _bind_server_socket(self):
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))

def _start_bootstrap_thread(self):
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
self._bind_server_socket()

def bootstrap_thread():
"""This thread recvs transfer info from the decode engine"""
Expand Down Expand Up @@ -452,23 +462,20 @@ def __init__(

def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
is_dummy = bootstrap_info["is_dummy"]
logger.debug(
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
)
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
GUARD,
str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.name.encode("ascii"),
kv_indices.tobytes() if not is_dummy else b"",
Expand All @@ -494,23 +501,20 @@ def poll(self) -> KVPoll:

def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)

sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
GUARD,
"None".encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.name.encode("ascii"),
self.kv_mgr.agent.get_agent_metadata(),
Expand Down
Loading
Loading