Merged
10 changes: 8 additions & 2 deletions fastdeploy/cache_manager/cache_messager.py
@@ -132,7 +132,10 @@ def __init__(
self.gpu_cache_kvs = gpu_cache_kvs
self.rank = rank
self.nranks = nranks
address = (pod_ip, engine_worker_queue_port)
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = (pod_ip, engine_worker_queue_port)
else:
address = f"/dev/shm/fd_task_queue_{engine_worker_queue_port}.sock"
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
@@ -423,7 +426,10 @@ def __init__(
self.gpu_cache_kvs = gpu_cache_kvs
self.rank = rank
self.nranks = nranks
address = (pod_ip, engine_worker_queue_port)
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = (pod_ip, engine_worker_queue_port)
else:
address = f"/dev/shm/fd_task_queue_{engine_worker_queue_port}.sock"
self.engine_worker_queue = EngineWorkerQueue(
address=address,
is_server=False,
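The same two-way address selection recurs in the engine, connector, and worker files below. A minimal sketch of the pattern, using a hypothetical task_queue_address helper that is not part of this PR:

```python
# Sketch only: mirrors the branch added in this PR; the helper name is hypothetical.
import os


def task_queue_address(pod_ip: str, port, use_shm: bool):
    """Return a TCP (ip, port) pair, or a /dev/shm Unix-socket path when SHM mode is on."""
    if not use_shm:
        return (pod_ip, int(port))
    return f"/dev/shm/fd_task_queue_{port}.sock"


# Example: the flag is read the same way fastdeploy.envs reads FD_ENGINE_TASK_QUEUE_WITH_SHM.
use_shm = bool(int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")))
print(task_queue_address("127.0.0.1", 8021, use_shm))
```

A Unix-domain socket under /dev/shm keeps same-host engine/worker traffic off the TCP stack, which is presumably the motivation for the new flag.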
11 changes: 7 additions & 4 deletions fastdeploy/engine/async_llm.py
@@ -923,10 +923,13 @@ def launch_components(self):
1,
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
):
address = (
self.cfg.master_ip,
int(self.cfg.parallel_config.engine_worker_queue_port[i]),
)
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = (
self.cfg.master_ip,
int(self.cfg.parallel_config.engine_worker_queue_port[i]),
)
else:
address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock"
llm_logger.info(f"dp start queue service {address}")
self.dp_engine_worker_queue_server.append(
EngineWorkerQueue(
35 changes: 22 additions & 13 deletions fastdeploy/engine/common_engine.py
@@ -270,10 +270,16 @@ def start_worker_queue_service(self, start_queue):
"""
start queue service for engine worker communication
"""
address = (
self.cfg.master_ip,
int(self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]),
)

if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = (
self.cfg.master_ip,
int(
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
),
)
else:
address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]}.sock"

if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"):
self.llm_logger.info(f"Starting engine worker queue server service at {address}")
@@ -284,15 +290,18 @@
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
)
# Dynamically updates the port value if an anonymous port is used
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] = str(
self.engine_worker_queue_server.get_server_port()
)
address = (
self.cfg.master_ip,
int(
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]
),
)
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] = (
str(self.engine_worker_queue_server.get_server_port())
)
address = (
self.cfg.master_ip,
int(
self.cfg.parallel_config.engine_worker_queue_port[
self.cfg.parallel_config.local_data_parallel_id
]
),
)

if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed":
self.cache_task_queue = EngineCacheQueue(
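Note that the anonymous-port back-fill above is now guarded: it only applies to TCP addresses, where requesting port 0 lets the OS choose a free port that must then be read back via get_server_port(); a /dev/shm socket path has no port to rediscover. A standard-library sketch of the anonymous-port behaviour (not FastDeploy code):

```python
# Sketch: binding to port 0 asks the OS for a free ("anonymous") port,
# which must then be read back and stored, as the engine does with get_server_port().
import socket

srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv.bind(("127.0.0.1", 0))            # anonymous port requested
actual_port = srv.getsockname()[1]    # port the OS actually assigned
print(f"listening on {actual_port}")
srv.close()
```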
16 changes: 11 additions & 5 deletions fastdeploy/engine/engine.py
@@ -707,7 +707,9 @@ def launch_components(self):
for i in range(self.cfg.parallel_config.data_parallel_size):
request_queues_for_dp_ipc.append(multiprocessing.Queue())
self.engine.scheduler.start(
self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc
self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node,
request_queues_for_dp_ipc,
result_queue_for_dp_ipc,
)

if not envs.FD_ENABLE_MULTI_API_SERVER:
@@ -719,10 +721,14 @@
1,
self.cfg.parallel_config.data_parallel_size // self.cfg.nnode,
):
address = (
self.cfg.master_ip,
int(self.cfg.parallel_config.engine_worker_queue_port[i]),
)
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = (
self.cfg.master_ip,
int(self.cfg.parallel_config.engine_worker_queue_port[i]),
)
else:
address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock"

llm_logger.info(f"dp start queue service {address}")
self.dp_engine_worker_queue_server.append(
EngineWorkerQueue(
1 change: 1 addition & 0 deletions fastdeploy/engine/request.py
@@ -161,6 +161,7 @@ def __init__(
self.extend_block_tables = []
# dp
self.dp_rank = dp_rank
self.llm_engine_recv_req_timestamp = time.time()

@classmethod
def from_dict(cls, d: dict):
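The new llm_engine_recv_req_timestamp field stamps each Request with the wall-clock time at which the engine received it. One plausible use, assumed here rather than shown in this PR, is measuring queueing delay further down the pipeline:

```python
# Assumed usage sketch: measure how long a request waited after the engine received it.
import time


class RequestStub:
    """Minimal stand-in for fastdeploy.engine.request.Request."""

    def __init__(self):
        self.llm_engine_recv_req_timestamp = time.time()


req = RequestStub()
time.sleep(0.01)  # pretend the request sat in a queue
queue_delay_ms = (time.time() - req.llm_engine_recv_req_timestamp) * 1000
print(f"queued for {queue_delay_ms:.1f} ms")
```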
2 changes: 2 additions & 0 deletions fastdeploy/envs.py
@@ -155,6 +155,8 @@
"ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"),
# Enable offline perf test mode for PD disaggregation
"FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
"FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")),
"FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")),
# ep+tp strategy: "all_reduce" or "all_to_all"
# all_reduce: qkv_linear + attn + out_linear + allreduce
# all_to_all: allgather + qkv_linear + attn + all2all + out_linear
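Both new entries follow the existing convention in fastdeploy/envs.py: a lambda that reads the environment variable and coerces it to int, so unset or "0" disables the feature and "1" enables it. An illustrative sketch of setting and reading the flags (the surrounding script is an assumption; the variable names come from the diff):

```python
# Illustrative only: how the two integer flags added above are toggled and read.
import os

# Enable shared-memory task queues and engine-to-worker tensor conversion for this process.
os.environ["FD_ENGINE_TASK_QUEUE_WITH_SHM"] = "1"
os.environ["FD_ENABLE_E2W_TENSOR_CONVERT"] = "1"

# Equivalent to the lambdas registered in fastdeploy/envs.py.
with_shm = int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0"))
e2w_convert = int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0"))
print(with_shm, e2w_convert)
```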
38 changes: 26 additions & 12 deletions fastdeploy/inter_communicator/engine_worker_queue.py
@@ -485,24 +485,38 @@ def _connect_with_retry(self, max_retries: int = 5, interval: int = 3) -> None:
@staticmethod
def to_tensor(tasks):
"""
Convert NumPy arrays in multimodal inputs to PaddlePaddle tensors.
Convert NumPy arrays in multimodal inputs to Paddle tensors.

Args:
tasks: List of tasks containing multimodal inputs.
tasks (tuple): ([request], bsz)
"""
if (not envs.FD_ENABLE_MAX_PREFILL) and (not envs.FD_ENABLE_E2W_TENSOR_CONVERT):
return
try:
if envs.FD_ENABLE_MAX_PREFILL:
llm_logger.debug(f"Convert image to tensor, type: {type(tasks)}")
batch_tasks, _ = tasks
for task in batch_tasks:
if not hasattr(task, "multimodal_inputs"):
batch_tasks, _ = tasks
for task in batch_tasks:
multimodal_inputs = getattr(task, "multimodal_inputs", None)
if not multimodal_inputs:
continue
# tensor keys
tensor_keys = [
"images",
"patch_idx",
"token_type_ids",
"position_ids",
"attention_mask_offset",
]

llm_logger.debug(f"Converting multimodal inputs to tensor...{tensor_keys}")

for key in tensor_keys:
value = multimodal_inputs.get(key)
if value is None:
continue
images = task.multimodal_inputs["images"]
if isinstance(images, np.ndarray):
llm_logger.debug(f"Convert image to tensor, shape: {images.shape}")
task.multimodal_inputs["images"] = paddle.to_tensor(images)
if not isinstance(value, paddle.Tensor):
multimodal_inputs[key] = paddle.to_tensor(value)
except Exception as e:
llm_logger.warning(f"Failed to convert to tensor: {e}")
llm_logger.warning(f"Tensor conversion failed: {type(e).__name__}: {e}")

@staticmethod
def to_numpy(tasks):
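The reworked to_tensor now walks a fixed list of multimodal keys and converts any non-tensor value with paddle.to_tensor, instead of handling only images. A self-contained sketch of that conversion loop on a toy task object (TinyTask is a stand-in, not the real Request class):

```python
# Sketch of the per-key conversion loop; TinyTask stands in for a FastDeploy request.
import numpy as np
import paddle

TENSOR_KEYS = ["images", "patch_idx", "token_type_ids", "position_ids", "attention_mask_offset"]


class TinyTask:
    def __init__(self):
        self.multimodal_inputs = {
            "images": np.zeros((2, 3, 224, 224), dtype="float32"),
            "token_type_ids": np.array([0, 0, 1, 1]),
            "patch_idx": None,  # None values are skipped, like missing keys
        }


task = TinyTask()
for key in TENSOR_KEYS:
    value = task.multimodal_inputs.get(key)
    if value is None or isinstance(value, paddle.Tensor):
        continue
    task.multimodal_inputs[key] = paddle.to_tensor(value)

print(type(task.multimodal_inputs["images"]))  # a paddle.Tensor after conversion
```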
18 changes: 15 additions & 3 deletions fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py
@@ -108,10 +108,22 @@ def init_ep(self, layer: nn.Layer) -> None:

# For non-mixed ep
phase = config.model_config.moe_phase.phase
if phase == "prefill":
self.ep_prefill_runner = self.EPPrefillRunner(**common_args)
if current_platform.is_cuda():
if phase == "prefill":
self.ep_prefill_runner = self.EPPrefillRunner(
**common_args,
use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage,
)
else:
self.ep_decoder_runner = self.EPDecoderRunner(
**common_args,
use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage,
)
else:
self.ep_decoder_runner = self.EPDecoderRunner(**common_args)
if phase == "prefill":
self.ep_prefill_runner = self.EPPrefillRunner(**common_args)
else:
self.ep_decoder_runner = self.EPDecoderRunner(**common_args)

def process_loaded_weights(self, layer, weights) -> None:
"""
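The branch above keeps the original phase-based runner selection on non-CUDA platforms and, on CUDA, additionally forwards use_internode_ll_two_stage to the chosen runner. A simplified sketch of that dispatch shape with placeholder runners (not the real EPPrefillRunner/EPDecoderRunner signatures):

```python
# Simplified dispatch sketch: pick a runner by phase, and only pass the
# two-stage flag on CUDA, mirroring the branch structure in the diff above.
def select_runner(phase: str, is_cuda: bool, use_internode_ll_two_stage: bool, **common_args):
    def prefill_runner(**kwargs):  # placeholder for EPPrefillRunner
        return ("prefill", kwargs)

    def decoder_runner(**kwargs):  # placeholder for EPDecoderRunner
        return ("decode", kwargs)

    runner = prefill_runner if phase == "prefill" else decoder_runner
    if is_cuda:
        return runner(**common_args, use_internode_ll_two_stage=use_internode_ll_two_stage)
    return runner(**common_args)


print(select_runner("prefill", True, False, num_experts=64))
```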
7 changes: 6 additions & 1 deletion fastdeploy/splitwise/splitwise_connector.py
@@ -330,8 +330,13 @@ def create_connection(self, port):
Parameters:
port (int): Port number.
"""
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
address = ("0.0.0.0", int(port))
else:
address = f"/dev/shm/fd_task_queue_{port}.sock"

self.connect_innode_instances[port] = EngineWorkerQueue(
address=("0.0.0.0", int(port)),
address=address,
num_client=self.cfg.parallel_config.tensor_parallel_size,
client_id=0,
)
4 changes: 4 additions & 0 deletions fastdeploy/worker/gpu_model_runner.py
@@ -569,6 +569,10 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
self.prompt_logprobs_reqs[request.request_id] = request
has_prefill_task = True
if (
self.fd_config.scheduler_config.splitwise_role == "decode"
): # In PD, we continue to decode after P generate first token
Review comment (Collaborator):
    has_prefill_task = False?

Reply (@rainyfly, Collaborator, Author, Nov 5, 2025):
    In gpu_model_runner.py this variable is only used to control need_not_stop for the current step, so no change is needed.
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
elif request.task_type.value == RequestType.DECODE.value: # decode task
logger.debug(f"Handle decode request {request} at idx {idx}")
encoder_block_num = len(request.block_tables)
11 changes: 7 additions & 4 deletions fastdeploy/worker/worker_process.py
@@ -556,10 +556,13 @@ def init_device(self) -> None:

def start_task_queue_service(self):
# Initialize task queue
task_address = (
self.parallel_config.pod_ip,
self.parallel_config.engine_worker_queue_port,
)
if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM:
task_address = (
self.parallel_config.pod_ip,
self.parallel_config.engine_worker_queue_port,
)
else:
task_address = f"/dev/shm/fd_task_queue_{self.parallel_config.engine_worker_queue_port}.sock"
logger.info(f"connect task queue address {task_address}")
self.task_queue = TaskQueue(
address=task_address,